diff --git a/axlearn/common/config.py b/axlearn/common/config.py
index 030769e51..f4c64588a 100644
--- a/axlearn/common/config.py
+++ b/axlearn/common/config.py
@@ -70,7 +70,7 @@ class Config(ConfigBase):
 from collections import defaultdict
 from collections.abc import Collection, Iterable
 from functools import cache
-from typing import Any, Callable, Generic, Optional, TypeVar, Union
+from typing import Any, Callable, Generic, NamedTuple, Optional, Sequence, TypeVar, Union
 
 # attr provides similar features as Python dataclass. Unlike
 # dataclass, however, it provides a richer set of features to regulate
@@ -394,6 +394,66 @@ def set(self, **kwargs):
             setattr(self, k, v)
         return self
 
+    class TraverseResult(NamedTuple):
+        """Result of a recursive traversal in a nested ConfigBase."""
+
+        # The parent config that contains the resulting key.
+        parent: _ConfigBase
+        # The key string.
+        key: str
+
+    def recursive_traverse(self, key_path: Sequence[str]) -> "ConfigBase.TraverseResult":
+        """Recursively traverse the config to find the target key.
+
+        Args:
+            key_path: A sequence of keys for indexing.
+
+        Raises:
+            ValueError: A key in key_path is not found.
+
+        Returns:
+            A TraverseResult containing the parent config and the target key.
+        """
+        target_key = key_path[0]
+        if target_key not in self:
+            raise ValueError(f"{target_key} is not found in {self}.")
+
+        if len(key_path) == 1:
+            # Completed traversal.
+            return self.TraverseResult(parent=self, key=key_path[0])
+
+        # Continue searching recursively.
+        value = getattr(self, target_key)
+        return value.recursive_traverse(key_path[1:])
+
+    def get_recursively(self, key_path: Sequence[str]) -> Any:
+        """Recursively find the target key in the config and return its value.
+
+        Args:
+            key_path: A sequence of keys for indexing to get the target value.
+
+        Raises:
+            ValueError: A key in key_path is not found.
+
+        Returns:
+            The value at the key_path.
+        """
+        traverse_result = self.recursive_traverse(key_path)
+        return getattr(traverse_result.parent, traverse_result.key)
+
+    def set_recursively(self, key_path: Sequence[str], new_value: Any):
+        """Recursively find the target key in the config and set its value.
+
+        Args:
+            key_path: A sequence of keys for indexing to set the target value.
+            new_value: New value to replace the target value.
+
+        Raises:
+            ValueError: A key in key_path is not found.
+        """
+        traverse_result = self.recursive_traverse(key_path)
+        setattr(traverse_result.parent, traverse_result.key, new_value)
+
     def clone(self, **kwargs):
         """Returns a clone of the original config with the optional keyword overrides.
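For orientation, a minimal usage sketch of the new accessors; `cfg` stands for any ConfigBase instance, and the `decoder`/`vocab_size` fields are illustrative, not part of this diff:

    # Reading and writing a nested field by key path (equivalent attribute chain in comments).
    value = cfg.get_recursively(["decoder", "vocab_size"])   # like cfg.decoder.vocab_size
    cfg.set_recursively(["decoder", "vocab_size"], 32000)    # like cfg.decoder.vocab_size = 32000
    # A key that is not present raises ValueError, e.g. cfg.get_recursively(["decoder", "missing"]).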
diff --git a/axlearn/common/config_test.py b/axlearn/common/config_test.py
index afb6799d6..c19b33e85 100644
--- a/axlearn/common/config_test.py
+++ b/axlearn/common/config_test.py
@@ -934,6 +934,60 @@ def set(self, **kwargs):
         self.assertEqual(123, cfg_clone.a)
         self.assertEqual("default", cfg_clone.b)
 
+    def test_get_recursively(self):
+        @config_class
+        class NestedConfig(ConfigBase):
+            """A dummy config."""
+
+            value: int = 0
+
+        @config_class
+        class TestConfig(ConfigBase):
+            """Another dummy config that has a nested config."""
+
+            nested: NestedConfig = NestedConfig()
+            value: int = 1
+
+        cfg = TestConfig()
+
+        # Test getting nested value.
+        self.assertEqual(cfg.get_recursively(["nested", "value"]), 0)
+
+        # Test getting top-level value.
+        self.assertEqual(cfg.get_recursively(["value"]), 1)
+
+        # Test getting non-existent value.
+        with self.assertRaises(ValueError):
+            cfg.get_recursively(["non_existent"])
+
+    def test_set_recursively(self):
+        @config_class
+        class NestedConfig(ConfigBase):
+            """A dummy config."""
+
+            value: int = 0
+
+        @config_class
+        class TestConfig(ConfigBase):
+            """Another dummy config that has a nested config."""
+
+            nested: NestedConfig = NestedConfig()
+            value: int = 1
+
+        cfg = TestConfig()
+
+        # Test setting nested value.
+        cfg.set_recursively(["nested", "value"], 10)
+        self.assertEqual(cfg.nested.value, 10)
+
+        # Test setting top-level value.
+        cfg.set_recursively(["value"], 5)
+        self.assertEqual(cfg.value, 5)
+
+        # Test setting non-existent value.
+        with self.assertRaises(ValueError):
+            cfg.set_recursively(["non_existent"], 20)
+
 
 if __name__ == "__main__":
     absltest.main()
diff --git a/axlearn/common/trainer_config_modifier.py b/axlearn/common/trainer_config_modifier.py
index d647e1a06..8b1499254 100644
--- a/axlearn/common/trainer_config_modifier.py
+++ b/axlearn/common/trainer_config_modifier.py
@@ -10,6 +10,7 @@
     REQUIRED,
     ConfigModifier,
     ConfigOr,
+    Configurable,
     Required,
     config_class,
     maybe_instantiate,
@@ -17,7 +18,7 @@
 from axlearn.common.gradient_accumulation import with_minibatch_steps
 from axlearn.common.metrics import MetricAccumulator
 from axlearn.common.trainer import SpmdTrainer
-from axlearn.common.utils import HybridMeshShape, MeshShape
+from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec
 
 
 class GradientAccumulationModifier(ConfigModifier):
@@ -100,18 +101,8 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
         """
         for module_name, remat_spec in self._remat_policies.items():
-            # Here we assume x.y.z format.
-            # One example would be model.decoder.transformer.layer.
-            target_modules = module_name.split(".")
-            curr_module = cfg
-            for target_module in target_modules:
-                if not hasattr(curr_module, target_module):
-                    raise ValueError(f"{target_module} is not found in {curr_module}.")
-                curr_module = getattr(curr_module, target_module)
-            # Here we assume all modules have remat_spec attribute.
-            if not hasattr(curr_module, "remat_spec"):
-                raise ValueError(f"{curr_module} does not have remat_spec attribute")
-            curr_module.remat_spec = remat_spec
+            cfg.set_recursively(module_name.split(".") + ["remat_spec"], remat_spec)
+
         return cfg
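The one-line replacement above relies on the same dotted-path convention that the removed block handled by hand. A rough equivalence sketch, assuming a hypothetical `model.decoder.transformer.layer` path:

    # "model.decoder.transformer.layer" -> ["model", "decoder", "transformer", "layer", "remat_spec"]
    cfg.set_recursively("model.decoder.transformer.layer".split(".") + ["remat_spec"], remat_spec)
    # This behaves like cfg.model.decoder.transformer.layer.remat_spec = remat_spec, except that a
    # missing path segment raises ValueError (from recursive_traverse) rather than AttributeError.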
@@ -146,6 +137,110 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
         return cfg
 
 
+class ModelConfigModifier(ConfigModifier):
+    """Update the model config for the trainer config."""
+
+    @config_class
+    class Config(ConfigModifier.Config):
+        """Configure ModelConfigModifier.
+
+        Attributes:
+            target_config: Module path (e.g. `model.decoder.transformer`) of the
+                target config to be replaced.
+            modification: The new config that replaces the config at `target_config`.
+        """
+
+        target_config: Required[str] = REQUIRED
+        modification: Required[Configurable.Config] = REQUIRED
+
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self._target_config = self.config.target_config
+        self._modification = self.config.modification
+
+    def _merge_configs(
+        self, target_cfg: Configurable.Config, found_module: Configurable.Config
+    ) -> Configurable.Config:
+        """Merge configurations from the config being replaced on a best-effort basis.
+
+        Merge rules:
+            - `klass` is not changed; target_cfg keeps its own class.
+            - If a field exists in both configs, use the value from the config being replaced.
+            - Otherwise, keep the value from target_cfg.
+
+        Args:
+            target_cfg: The configuration that will replace found_module.
+            found_module: The existing configuration whose class will be replaced,
+                but whose configuration is merged into target_cfg.
+
+        Returns:
+            The modified config.
+        """
+        for key in target_cfg.keys():
+            if key == "klass":
+                continue
+            elif hasattr(found_module, key) and hasattr(target_cfg, key):
+                setattr(target_cfg, key, getattr(found_module, key))
+        return target_cfg
+
+    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
+        """Overwrite the config of the target module.
+
+        Args:
+            cfg: The trainer config to be modified.
+
+        Raises:
+            ValueError: The target module is not found.
+
+        Returns:
+            The modified trainer config.
+        """
+
+        found_module = cfg.get_recursively(self._target_config.split("."))
+        self._modification = self._merge_configs(self._modification, found_module)
+        cfg.set_recursively(self._target_config.split("."), self._modification)
+        return cfg
+
+
+class PartitionSpecModifier(ConfigModifier):
+    """Update the partition spec attribute for the specified modules."""
+
+    @config_class
+    class Config(ConfigModifier.Config):
+        """Configure PartitionSpecModifier.
+
+        Attributes:
+            partition_specs: A nested mapping from module path
+                (e.g. `model.decoder.transformer.layer`) to another
+                mapping of module attribute to PartitionSpec.
+        """
+
+        partition_specs: Required[Dict[str, Dict[str, PartitionSpec]]] = REQUIRED
+
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self._attribute_dicts = self.config.partition_specs
+
+    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
+        """Update the partition_spec attributes for the specified modules.
+
+        Args:
+            cfg: The trainer config to be modified.
+
+        Raises:
+            ValueError: The target module is not found.
+            ValueError: The partition_spec attribute is not found.
+
+        Returns:
+            The modified trainer config.
+        """
+        for module_name, partition_spec_dict in self._attribute_dicts.items():
+            for partition_spec_name, partition_spec in partition_spec_dict.items():
+                cfg.set_recursively(module_name.split(".") + [partition_spec_name], partition_spec)
+
+        return cfg
+
+
 class ChainConfigModifier(ConfigModifier):
     """Chain multiple config modifiers together."""
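For context, a sketch of how the two new modifiers might be chained under a mesh rule, mirroring the tests and golden configs below; the target path and partition specs here are illustrative, and `trainer_cfg` is assumed to be an SpmdTrainer.Config:

    chain_cfg = ChainConfigModifier.default_config().set(
        config_modifiers=[
            ModelConfigModifier.default_config().set(
                target_config="model.decoder.transformer",
                modification=RepeatedTransformerLayer.default_config(),
            ),
            PartitionSpecModifier.default_config().set(
                partition_specs={
                    "model.decoder.lm_head": {
                        "param_partition_spec": ("model", ("expert", "fsdp", "seq")),
                    },
                },
            ),
        ],
    )
    trainer_cfg = chain_cfg.instantiate()(trainer_cfg)  # Applies each modifier in order.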
diff --git a/axlearn/common/trainer_config_modifier_test.py b/axlearn/common/trainer_config_modifier_test.py
index ccfe00823..8ab818588 100644
--- a/axlearn/common/trainer_config_modifier_test.py
+++ b/axlearn/common/trainer_config_modifier_test.py
@@ -5,13 +5,16 @@
 import jax
 from absl.testing import absltest
 
-from axlearn.common import test_utils
+from axlearn.common import causal_lm, test_utils
+from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.trainer import SpmdTrainer
 from axlearn.common.trainer_config_modifier import (
     ChainConfigModifier,
     GradientAccumulationModifier,
     MeshShapeModifier,
+    ModelConfigModifier,
+    PartitionSpecModifier,
     RematSpecModifier,
 )
 from axlearn.common.trainer_test import DummyModel
@@ -65,6 +68,86 @@ def test_remat_policy_override(self):
             _ = cfg_modifier(cfg)
 
 
+class ModelConfigModifierTest(test_utils.TestCase):
+    def test_model_config_override(self):
+        cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
+        self.assertTrue(
+            str(cfg.model.decoder.transformer) == str(StackedTransformerLayer.default_config())
+        )
+
+        cfg_modifier = (
+            ModelConfigModifier.default_config()
+            .set(
+                target_config="model.decoder.transformer",
+                modification=RepeatedTransformerLayer.default_config(),
+            )
+            .instantiate()
+        )
+
+        cfg = cfg_modifier(cfg)
+        # The default StackedTransformerLayer should have changed to RepeatedTransformerLayer.
+        self.assertTrue(
+            str(cfg.model.decoder.transformer) == str(RepeatedTransformerLayer.default_config())
+        )
+        cfg_modifier = (
+            ModelConfigModifier.default_config()
+            .set(
+                target_config="model.decoder.unknown",
+                modification=RepeatedTransformerLayer.default_config(),
+            )
+            .instantiate()
+        )
+        # Ensure that the exception is working.
+        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
+            _ = cfg_modifier(cfg)
+
+
+class PartitionSpecModifierTest(test_utils.TestCase):
+    def test_partition_spec_override(self):
+        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
+        cfg_modifier = (
+            PartitionSpecModifier.default_config()
+            .set(
+                partition_specs={
+                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
+                },
+            )
+            .instantiate()
+        )
+        cfg = cfg_modifier(cfg)
+        self.assertEqual(
+            cfg.model.linear.param_partition_spec, ("model", ("expert", "fsdp", "seq"))
+        )
+        cfg_modifier = (
+            PartitionSpecModifier.default_config()
+            .set(
+                partition_specs={
+                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
+                    "model.unknown": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
+                },
+            )
+            .instantiate()
+        )
+        # Ensure that the exception is working.
+ with self.assertRaisesRegex(ValueError, "unknown is not found in.*"): + _ = cfg_modifier(cfg) + + cfg_modifier = ( + PartitionSpecModifier.default_config() + .set( + partition_specs={ + "model.linear": { + "param_partition_spec": ("model", ("expert", "fsdp", "seq")), + "unknown_partition_spec": ("model", ("expert", "fsdp", "seq")), + }, + }, + ) + .instantiate() + ) + with self.assertRaisesRegex(ValueError, "unknown_partition_spec is not found in.*"): + _ = cfg_modifier(cfg) + + class MeshShapeModifierTest(test_utils.TestCase): def test_mesh_shape_update(self): cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt index cdff10a4c..986575d71 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' 
+mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt index 4393f09bf..f2abcecd7 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None 
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 
'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt index 02926944a..66842ab38 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 
'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True 
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt index e92520fa8..ec065a6d0 100644 --- 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True 
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash.txt index 67b87a020..0df5feb88 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' 
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 
'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt index f01ea2bf9..332ac1994 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None 
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 
'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
+mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt
index ed0018f69..00e9b49d6 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt
@@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert'
 mesh_axis_names[3]: 'fsdp'
 mesh_axis_names[4]: 'seq'
 mesh_axis_names[5]: 'model'
+mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu'
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm'
+mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt index ce42ebc30..14c8bda68 100644 --- 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 
'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt index 3f0c10291..b9b87ee15 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' 
+mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt index 3e6436f4c..ec12261cf 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None 
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 
'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt index 3e6a68d6e..a0e0898b7 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 
'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True 
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host.txt index 00fdc9ff7..01411bd60 100644 --- 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True 
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash.txt index e2670708e..5e7d9188a 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' 
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 
'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt index 10ddc09e1..7d074f037 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None 
+mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 
'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt index 33082cd80..d9278aeee 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 
'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True 
+mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt index b069f70a1..857df2673 100644 --- 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 
'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' 
+mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt index 3f3a02811..f3b0c55bc 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 367001 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 367001 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 2048 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,84 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' 
+mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None 
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' 
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt index e5d388a31..f287eb939 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 367001 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 367001 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.global_batch_size: 8 
evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 2048 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,84 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' 
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 
'seq' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None 
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt index 1c8dc6844..ab8c0e563 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 524288 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 1024 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 524288 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 1024 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 1024 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,92 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 
'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' 
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[3].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.bias: True +mesh_rules[4][1].config_modifiers[3].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[3].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[4][1].config_modifiers[4].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' 
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt index 02b00035f..916a6ac29 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 524288 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 1024 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 
evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 524288 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 1024 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 1024 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,92 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 
+mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' 
+mesh_rules[4][1].config_modifiers[3].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.bias: True +mesh_rules[4][1].config_modifiers[3].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[3].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[4][1].config_modifiers[4].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' 
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt index 14530bb04..adc730c58 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 2048 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 
input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,92 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None 
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[3].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.bias: True +mesh_rules[4][1].config_modifiers[3].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[3].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[4][1].config_modifiers[4].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 
'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' 
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt index 3ffbaf7be..253b94df0 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 2048 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,92 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' 
+mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None 
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[3].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.bias: True +mesh_rules[4][1].config_modifiers[3].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[3].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[4][1].config_modifiers[4].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' 
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt index 37676eb4b..eadeb917e 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.global_batch_size: 8 
evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 2048 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,92 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' 
+mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 
'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[3].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.bias: True +mesh_rules[4][1].config_modifiers[3].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[3].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[4][1].config_modifiers[4].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' 
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt index 8efdf10e2..21ba046cf 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 2048 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 
'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,92 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None 
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm'
+mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer'
+mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[4][1].config_modifiers[3].modification.klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[4][1].config_modifiers[3].modification.layer.bias: True
+mesh_rules[4][1].config_modifiers[3].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[3].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'
+mesh_rules[4][1].config_modifiers[4].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt
index 94c96a380..923359ff6 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt
@@ -202,6 +202,93 @@ mesh_rules[6][1][2]: 1
 mesh_rules[6][1][3]: 8
 mesh_rules[6][1][4]: 1
 mesh_rules[6][1][5]: 1
+mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm'
+mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm'
+mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer'
+mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt
index f72349af7..a12719f2f 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt
@@ -202,6 +202,93 @@ mesh_rules[6][1][2]: 1
 mesh_rules[6][1][3]: 8
 mesh_rules[6][1][4]: 1
 mesh_rules[6][1][5]: 1
+mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm'
+mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm'
+mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer'
+mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt
index a3f8ac77e..f747fc568 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt
@@ -202,6 +202,93 @@ mesh_rules[6][1][2]: 1
 mesh_rules[6][1][3]: 8
 mesh_rules[6][1][4]: 1
 mesh_rules[6][1][5]: 1
+mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm'
+mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm'
+mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer'
+mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt
index c4426cac3..0ebd87ae0 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt
@@ -202,6 +202,93 @@ mesh_rules[6][1][2]: 1
 mesh_rules[6][1][3]: 8
 mesh_rules[6][1][4]: 1
 mesh_rules[6][1][5]: 1
+mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm'
+mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' 
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' 
+mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt index 40d58e819..ad59e47a6 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt @@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' 
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' 
+mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None 
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt index 98be3b833..b78f52af1 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt @@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' 
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' 
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None 
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: -1
mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt
index 0e057f4b9..1ced561b5 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt
@@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1
mesh_rules[6][1][3]: 8
mesh_rules[6][1][4]: 1
mesh_rules[6][1][5]: 1
+mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None 
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' 
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: -1
mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt
index 8d69c9254..04bbf53df 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt
@@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1
mesh_rules[6][1][3]: 8
mesh_rules[6][1][4]: 1
mesh_rules[6][1][5]: 1
+mesh_rules[7][0]:
'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' 
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' 
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None 
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: -1
mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt
index 467258bf0..5227067ef 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt
@@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1
mesh_rules[6][1][3]: 8
mesh_rules[6][1][4]: 1
mesh_rules[6][1][5]: 1
+mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm'
+mesh_rules[7][1].config_modifiers[1].modification.layer.klass:
'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None 
+mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' 
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: -1
mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt
index 47fb69af4..47d614372 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt
@@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1
mesh_rules[6][1][3]: 8
mesh_rules[6][1][4]: 1
mesh_rules[6][1][5]: 1
+mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' 
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' 
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
+mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: -1
mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt
index 27dc49fb1..a49c47dae 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt
@@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1
mesh_rules[6][1][3]: 8
mesh_rules[6][1][4]: 1
mesh_rules[6][1][5]: 1
+mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass:
'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' 
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' 
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt index f391b0abc..77bbabdde 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt @@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 
'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' 
+mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' 
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None 
+mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt index 878b7889a..1833c683c 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt @@ -186,6 +186,101 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' 
+mesh_rules[6][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[6][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[6][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None 
+mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[6][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' 
+mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt index bd7c71f4a..347f7007a 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt @@ -186,6 +186,101 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' 
+mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' 
+mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[6][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[6][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[6][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None 
+mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt index 5e0762544..01a1e850c 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt @@ -186,6 +186,101 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' 
+mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None 
+mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[6][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[6][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[6][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' 
+mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt index 17ba6f233..e019219a8 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt @@ -186,6 +186,101 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 
mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' 
+mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[6][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[6][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[6][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' 
+mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None 
+mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index bbd769dad..6f9635ead 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -22,11 +22,13 @@ BaseStackedTransformerLayer, FusedGroupedQKVLinear, FusedQKVLinear, + GroupedQKVLinear, GroupedQueryAttention, MultiheadAttention, RematRegexSavePatterns, RepeatedTransformerLayer, RoFormerQKVLinear, + StackedTransformerLayer, ) from axlearn.common.base_layer import RematSpec from axlearn.common.config import config_for_function @@ -38,6 +40,8 @@ ChainConfigModifier, GradientAccumulationModifier, MeshShapeModifier, + ModelConfigModifier, + PartitionSpecModifier, RematSpecModifier, ) from axlearn.common.utils import ( @@ -151,6 +155,60 @@ def get_trainer_kwargs( rope_theta = ROPE_THETA[version] + # TRN2 specific model config modifications + trn2_model_modifications = [ + # Neuron compiler has a module to detect repeating blocks and reuse them during compilation. + # So compile time does not grow with the number of layers. + ModelConfigModifier.default_config().set( + target_config="model.decoder.transformer", + modification=StackedTransformerLayer.default_config(), + ) + ] + if version != Version.V1: + trn2_model_modifications.append( + ModelConfigModifier.default_config().set( + target_config="model.decoder.transformer.layer.self_attention.attention." + "input_linear.input_linear", + modification=GroupedQKVLinear.default_config(), + ) + ) + + trn2_partition_spec_modifications = [ + PartitionSpecModifier.default_config().set( + partition_specs={ + # Vocab parallel embeddings sharding from Megatron LM. + "model.decoder.emb.token_emb": { + "param_partition_spec": ( + "model", + ("expert", "fsdp", "seq"), + ), + "input_partition_spec": ("fsdp", None), + "output_partition_spec": ("fsdp", "model"), + "embedding_partition_spec": ("model", "fsdp"), + }, + "model.decoder.lm_head": { + "param_partition_spec": ( + "model", + ("expert", "fsdp", "seq"), + ), + }, + # Sequence parallel shardings for norms. + "model.decoder.transformer.layer.self_attention.norm": { + "input_partition_spec": ("fsdp", "model", None), + "output_partition_spec": ("fsdp", None, None), + }, + "model.decoder.transformer.layer.feed_forward.norm": { + "input_partition_spec": ("fsdp", "model", None), + "output_partition_spec": ("fsdp", None, None), + }, + "model.decoder.output_norm": { + "input_partition_spec": ("fsdp", "model", None), + "output_partition_spec": ("fsdp", None, None), + }, + }, + ), + ] + offload_dots_saveable_policy = config_for_function( extended_checkpoint_policies.offload_dots_saveable ).set(offload_src="device", offload_dst="pinned_host") @@ -204,6 +262,22 @@ def get_trainer_kwargs( train_batch_size=train_batch_size, max_step=max_step, mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), + mesh_rules=( + ( + "neuron-(trn2|trn2n).48xlarge-64", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + # TP within the chip, FSDP across chips. + # Each TRN2 chip has 4 XLA cores. 
+ mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4) + ), + *trn2_model_modifications, + *trn2_partition_spec_modifications, + ], + ), + ), + ), ) elif model_size == "3B": trainer_kwargs = dict( @@ -222,6 +296,22 @@ def get_trainer_kwargs( train_batch_size=train_batch_size, max_step=max_step, mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), + mesh_rules=( + ( + "neuron-(trn2|trn2n).48xlarge-64", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + # TP within the chip, FSDP across chips. + # Each TRN2 chip has 4 XLA cores. + mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4) + ), + *trn2_model_modifications, + *trn2_partition_spec_modifications, + ], + ), + ), + ), ) elif model_size == "7B": trainer_kwargs = dict( @@ -335,6 +425,20 @@ def get_trainer_kwargs( "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)", mesh_shape_from_axes(data=-1, fsdp=8), ), + ( + "neuron-(trn2|trn2n).48xlarge-64", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + # TP within the chip, FSDP across chips. + # Each TRN2 chip has 4 XLA cores. + mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4) + ), + *trn2_model_modifications, + *trn2_partition_spec_modifications, + ], + ), + ), ), ) elif model_size == "8B": @@ -415,6 +519,20 @@ def get_trainer_kwargs( "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)", mesh_shape_from_axes(data=-1, fsdp=8), ), + ( + "neuron-(trn2|trn2n).48xlarge-64", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + # TP within the chip, FSDP across chips. + # Each TRN2 chip has 4 XLA cores. + mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4) + ), + *trn2_model_modifications, + *trn2_partition_spec_modifications, + ], + ), + ), ), ) elif model_size == "70B": @@ -433,7 +551,7 @@ def get_trainer_kwargs( ), learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1), max_sequence_length=max_sequence_length, - train_batch_size=train_batch_size, + train_batch_size=8, max_step=max_step, mesh_shape=mesh_shape_from_axes(fsdp=-1), mesh_rules=( @@ -509,6 +627,8 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( + # TP within the chip, FSDP across chips. + # Each TRN2 chip has 4 XLA cores. mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4) ), RematSpecModifier.default_config().set( @@ -531,6 +651,8 @@ def get_trainer_kwargs( ), } ), + *trn2_model_modifications, + *trn2_partition_spec_modifications, ], ), ),
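Note on the partition specs: the nested tuples in `trn2_partition_spec_modifications` (and in the expanded golden configs above) follow the usual axlearn convention of mapping onto JAX sharding axes. A minimal sketch of what the vocab-parallel token-embedding entry denotes, assuming the embedding table has the usual (vocab, hidden) layout; this is illustrative only and not part of the diff:

    from jax.sharding import PartitionSpec

    # "model.decoder.emb.token_emb" param_partition_spec: shard the vocab
    # dimension over the "model" mesh axis and the hidden dimension jointly
    # over the ("expert", "fsdp", "seq") axes, i.e. Megatron-LM style
    # vocab-parallel embeddings as noted in the fuji.py comment above.
    token_emb_spec = PartitionSpec("model", ("expert", "fsdp", "seq"))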
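For reference, a minimal sketch of how one of the TRN2 mesh rules added above takes effect. In the trainer, a rule is selected by matching the mesh selector string (e.g. a neuron-trn2.48xlarge-64 instance descriptor) against the regex key; the sketch skips that step and applies the chain directly. It assumes an SpmdTrainer config named `trainer_cfg` built elsewhere (the name is hypothetical) and the `trn2_model_modifications` / `trn2_partition_spec_modifications` lists defined in the fuji.py hunk:

    # Sketch only, not part of the diff.
    trn2_chain = ChainConfigModifier.default_config().set(
        config_modifiers=[
            MeshShapeModifier.default_config().set(
                # TP within the chip, FSDP across chips (4 XLA cores per TRN2 chip).
                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
            ),
            *trn2_model_modifications,
            *trn2_partition_spec_modifications,
        ],
    )
    # Config modifiers instantiate to callables that take and return an
    # SpmdTrainer.Config; ChainConfigModifier applies its members in order.
    trainer_cfg = trn2_chain.instantiate()(trainer_cfg)

The resulting config is what the expanded golden files above capture: the StackedTransformerLayer / GroupedQKVLinear swaps, followed by the vocab-parallel embedding and sequence-parallel norm shardings.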