TRN2 Meshes and Configurations #916

Open · wants to merge 3 commits into base: main
62 changes: 61 additions & 1 deletion axlearn/common/config.py
@@ -70,7 +70,7 @@ class Config(ConfigBase):
from collections import defaultdict
from collections.abc import Collection, Iterable
from functools import cache
-from typing import Any, Callable, Generic, Optional, TypeVar, Union
+from typing import Any, Callable, Generic, NamedTuple, Optional, Sequence, TypeVar, Union

# attr provides similar features as Python dataclass. Unlike
# dataclass, however, it provides a richer set of features to regulate
@@ -394,6 +394,66 @@ def set(self, **kwargs):
            setattr(self, k, v)
        return self

    class TraverseResult(NamedTuple):
        """Result of a recursive traversal of a nested ConfigBase."""

        # The parent config that contains the resulting key.
        parent: _ConfigBase
        # The key string.
        key: str

    def recursive_traverse(self, key_path: Sequence[str]) -> tuple[Any, str]:
[Contributor comment] Does this need to be a public method?

"""Recursively traverse the config to find the target key.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please see other comment re get_recursively; also, I wonder whether we actually need recursion here (seems like a loop would be simpler).


        Args:
            key_path: A sequence of keys for indexing.

        Raises:
            ValueError: A key in key_path is not found.

        Returns:
            A tuple containing the parent object and the target key.
        """
        target_key = key_path[0]
        if target_key not in self:
            raise ValueError(f"{target_key} is not found in {self}.")

        if len(key_path) == 1:
            # Completed traversal.
            return self.TraverseResult(parent=self, key=key_path[0])

        # Continue searching recursively.
        value = getattr(self, target_key)
        return value.recursive_traverse(key_path[1:])
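
A minimal iterative sketch of the same traversal, responding to the loop question above. It relies only on the `in` and getattr behavior the recursive version already uses; the name traverse_loop is illustrative and this is not part of the diff:

    def traverse_loop(self, key_path: Sequence[str]) -> "TraverseResult":
        """Hypothetical loop-based equivalent of recursive_traverse."""
        # Like the recursive version, assumes key_path is non-empty.
        parent = self
        # Walk down to the parent of the final key.
        for key in key_path[:-1]:
            if key not in parent:
                raise ValueError(f"{key} is not found in {parent}.")
            parent = getattr(parent, key)
        if key_path[-1] not in parent:
            raise ValueError(f"{key_path[-1]} is not found in {parent}.")
        return self.TraverseResult(parent=parent, key=key_path[-1])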

    def get_recursively(self, key_path: Sequence[str]) -> Any:
[Contributor suggestion]
-    def get_recursively(self, key_path: Sequence[str]) -> Any:
+    def get_recursively(self, path: Sequence[str]) -> Any:

"""Recursively find the target key in the config and return its value.

Args:
key_path: A sequence of keys for indexing to get the target value.
[Contributor comment] Can path be empty? Maybe it can return self if path is empty?


        Raises:
            ValueError: A key in key_path is not found.

        Returns:
            The value at key_path.
        """
        traverse_result = self.recursive_traverse(key_path)
        return getattr(traverse_result.parent, traverse_result.key)
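
One way to honor the empty-path suggestion above, sketched under the assumption that an empty path should name the config itself (hypothetical variant, not in this diff):

    def get_recursively(self, path: Sequence[str]) -> Any:
        if not path:
            # An empty path refers to the config itself.
            return self
        result = self.recursive_traverse(path)
        return getattr(result.parent, result.key)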

    def set_recursively(self, key_path: Sequence[str], new_value: Any):
[Contributor comment] Please name this consistently with:

    def set_recursively(
        x: NestedTensor,
        *,
        value: Tensor,

[Contributor suggestion]
-    def set_recursively(self, key_path: Sequence[str], new_value: Any):
+    def set_recursively(self, path: Sequence[str], *, value: Any):

"""Recursively find the target key in the config and set its value.

Args:
key_path: A sequence of keys for indexing to set the target value.
[Contributor comment] Can path be empty?

            new_value: New value to replace the target value.

        Raises:
            ValueError: A key in key_path is not found.
        """
        traverse_result = self.recursive_traverse(key_path)
[Contributor comment] Can we do something like:

[Contributor suggestion]
-        traverse_result = self.recursive_traverse(key_path)
+        if not path:
+            raise ValueError(...)
+        parent = self.get_recursively(path[:-1])

        setattr(traverse_result.parent, traverse_result.key, new_value)
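
Filling out the suggestion above, a non-recursive set_recursively could be built on get_recursively. This sketch assumes the suggested path/value naming and the empty-path-tolerant get_recursively sketched earlier (illustrative only):

    def set_recursively(self, path: Sequence[str], *, value: Any):
        if not path:
            raise ValueError("path must be non-empty.")
        # get_recursively([]) returns self, so this also handles top-level keys.
        parent = self.get_recursively(path[:-1])
        if path[-1] not in parent:
            raise ValueError(f"{path[-1]} is not found in {parent}.")
        setattr(parent, path[-1], value)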

    def clone(self, **kwargs):
        """Returns a clone of the original config with the optional keyword overrides.

54 changes: 54 additions & 0 deletions axlearn/common/config_test.py
@@ -934,6 +934,60 @@ def set(self, **kwargs):
        self.assertEqual(123, cfg_clone.a)
        self.assertEqual("default", cfg_clone.b)

    def test_get_recursively(self):
        @config_class
        class NestedConfig(ConfigBase):
            """A dummy config."""

            value: int = 0

        @config_class
        class TestConfig(ConfigBase):
            """Another dummy config that has a nested config."""

            nested: NestedConfig = NestedConfig()
            value: int = 1

        cfg = TestConfig()

        # Test getting a nested value.
        self.assertEqual(cfg.get_recursively(["nested", "value"]), 0)

        # Test getting a top-level value.
        self.assertEqual(cfg.get_recursively(["value"]), 1)

        # Test getting a non-existent value.
        with self.assertRaises(ValueError):
            cfg.get_recursively(["non_existent"])

    def test_set_recursively(self):
        @config_class
        class NestedConfig(ConfigBase):
            """A dummy config."""

            value: int = 0

        @config_class
        class TestConfig(ConfigBase):
            """Another dummy config that has a nested config."""

            nested: NestedConfig = NestedConfig()
            value: int = 1

        cfg = TestConfig()

        # Test setting a nested value.
        cfg.set_recursively(["nested", "value"], 10)
        self.assertEqual(cfg.nested.value, 10)

        # Test setting a top-level value.
        cfg.set_recursively(["value"], 5)
        self.assertEqual(cfg.value, 5)

        # Test setting a non-existent value.
        with self.assertRaises(ValueError):
            cfg.set_recursively(["non_existent"], 20)


if __name__ == "__main__":
    absltest.main()
121 changes: 108 additions & 13 deletions axlearn/common/trainer_config_modifier.py
@@ -10,14 +10,15 @@
    REQUIRED,
    ConfigModifier,
    ConfigOr,
+    Configurable,
    Required,
    config_class,
    maybe_instantiate,
)
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.trainer import SpmdTrainer
-from axlearn.common.utils import HybridMeshShape, MeshShape
+from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec


class GradientAccumulationModifier(ConfigModifier):
@@ -100,18 +101,8 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """

        for module_name, remat_spec in self._remat_policies.items():
-            # Here we assume x.y.z format.
-            # One example would be model.decoder.transformer.layer.
-            target_modules = module_name.split(".")
-            curr_module = cfg
-            for target_module in target_modules:
-                if not hasattr(curr_module, target_module):
-                    raise ValueError(f"{target_module} is not found in {curr_module}.")
-                curr_module = getattr(curr_module, target_module)
-            # Here we assume all modules have remat_spec attribute.
-            if not hasattr(curr_module, "remat_spec"):
-                raise ValueError(f"{curr_module} does not have remat_spec attribute.")
-            curr_module.remat_spec = remat_spec
+            cfg.set_recursively(module_name.split(".") + ["remat_spec"], remat_spec)

        return cfg
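
With this refactor, a remat policy is addressed purely by its dotted module path. A minimal usage sketch, assuming the config field is named remat_policies (as the attribute self._remat_policies suggests) and using RematSpec's prevent_cse flag with an illustrative value:

    cfg_modifier = (
        RematSpecModifier.default_config()
        .set(
            remat_policies={
                "model.decoder.transformer.layer": RematSpec(prevent_cse=False),
            }
        )
        .instantiate()
    )
    trainer_cfg = cfg_modifier(trainer_cfg)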


@@ -146,6 +137,110 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        return cfg


class ModelConfigModifier(ConfigModifier):
[Contributor comment] Which part of this class is specific to model? It seems to take generic modifications?
"""Update the model config for the trainer config."""

@config_class
class Config(ConfigModifier.Config):
"""Configure ModelConfigModifier.

Attributes:
model_cfg_modifications: A mapping from module path
[Contributor comment] Outdated?
                (e.g. `model.decoder.transformer.layer`) to a Config.
        """

        target_config: Required[str] = REQUIRED
        modification: Required[Configurable.Config] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._target_config = self.config.target_config
        self._modification = self.config.modification

    def _merge_configs(
        self, target_cfg: Configurable.Config, found_module: Configurable.Config
    ) -> Configurable.Config:
        """Merge configurations from the config being replaced on a best effort basis.

        Merge Rules:
            - Klass is not changed, use target cfg
[Contributor suggestion]
-            - Klass is not changed, use target cfg
+            - Klass is not changed, use target cfg.
Please end all sentences with punctuation.
            - If a field exists in both, use the value from the class being replaced.
            - Otherwise keep the value from target_cfg.

        Args:
            target_cfg: configuration that will replace found_module.
            found_module: existing configuration whose class will be replaced
[Contributor suggestion]
-            target_cfg: configuration that will replace found_module.
-            found_module: existing configuration whose class will be replaced
+            target_cfg: Configuration that will replace found_module.
+            found_module: Existing configuration whose class will be replaced,
+                but its configuration will be merged with target_cfg.
Returns:
The modified config.

"""
for key in target_cfg.keys():
if key == "klass":
continue
elif hasattr(found_module, key) and hasattr(target_cfg, key):
setattr(target_cfg, key, getattr(found_module, key))
return target_cfg
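
To make the merge rules concrete, a hypothetical before/after; the direct _merge_configs call and the num_layers field are used for illustration, on the assumption that both transformer configs expose num_layers:

    modifier = ModelConfigModifier.default_config().set(
        target_config="model.decoder.transformer",
        modification=RepeatedTransformerLayer.default_config(),
    ).instantiate()
    stacked = StackedTransformerLayer.default_config().set(num_layers=12)
    merged = modifier._merge_configs(RepeatedTransformerLayer.default_config(), stacked)
    assert merged.klass is RepeatedTransformerLayer  # klass is never merged.
    assert merged.num_layers == 12  # Shared field copied from the replaced config.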

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Overwrite the model config of the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.

        Returns:
            The modified trainer config.
        """

        found_module = cfg.get_recursively(self._target_config.split("."))
        self._modification = self._merge_configs(self._modification, found_module)
        cfg.set_recursively(self._target_config.split("."), self._modification)
        return cfg


class PartitionSpecModifier(ConfigModifier):
    """Update the partition spec attribute for the specified modules."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure PartitionSpecModifier.

        Attributes:
            partition_specs: A nested mapping from module path
                (e.g. `model.decoder.transformer.layer`) to another
                mapping of model attribute to PartitionSpec.
        """

        partition_specs: Required[Dict[str, PartitionSpec]] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._attribute_dicts = self.config.partition_specs

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Update the partition_spec attributes for the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.
            ValueError: The partition_spec attribute is not found.

        Returns:
            The modified trainer config.
        """
        for module_name, partition_spec_dict in self._attribute_dicts.items():
            for partition_spec_name, partition_spec in partition_spec_dict.items():
                cfg.set_recursively(module_name.split(".") + [partition_spec_name], partition_spec)

        return cfg
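
In practice these modifiers are composed. A minimal sketch chaining the new PartitionSpecModifier with the existing MeshShapeModifier via ChainConfigModifier; the config_modifiers field name and the mesh_shape value are assumptions for illustration:

    chain = (
        ChainConfigModifier.default_config()
        .set(
            config_modifiers=[
                MeshShapeModifier.default_config().set(mesh_shape=(1, 1, 4, 1)),
                PartitionSpecModifier.default_config().set(
                    partition_specs={
                        "model.linear": {
                            "param_partition_spec": ("model", ("expert", "fsdp", "seq")),
                        },
                    },
                ),
            ]
        )
        .instantiate()
    )
    trainer_cfg = chain(trainer_cfg)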


class ChainConfigModifier(ConfigModifier):
"""Chain multiple config modifiers together."""

85 changes: 84 additions & 1 deletion axlearn/common/trainer_config_modifier_test.py
@@ -5,13 +5,16 @@
import jax
from absl.testing import absltest

-from axlearn.common import test_utils
+from axlearn.common import causal_lm, test_utils
+from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer
from axlearn.common.base_layer import RematSpec
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import (
    ChainConfigModifier,
    GradientAccumulationModifier,
    MeshShapeModifier,
+    ModelConfigModifier,
+    PartitionSpecModifier,
    RematSpecModifier,
)
from axlearn.common.trainer_test import DummyModel
@@ -65,6 +68,86 @@ def test_remat_policy_override(self):
            _ = cfg_modifier(cfg)


class ModelConfigModifierTest(test_utils.TestCase):
    def test_model_config_override(self):
        cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
        self.assertTrue(
            str(cfg.model.decoder.transformer) == str(StackedTransformerLayer.default_config())
        )

        cfg_modifier = (
            ModelConfigModifier.default_config()
            .set(
                target_config="model.decoder.transformer",
                modification=RepeatedTransformerLayer.default_config(),
            )
            .instantiate()
        )

        cfg = cfg_modifier(cfg)
        # The default StackedTransformerLayer should have changed to RepeatedTransformerLayer.
        self.assertTrue(
            str(cfg.model.decoder.transformer) == str(RepeatedTransformerLayer.default_config())
        )
        cfg_modifier = (
            ModelConfigModifier.default_config()
            .set(
                target_config="model.decoder.unknown",
                modification=RepeatedTransformerLayer.default_config(),
            )
            .instantiate()
        )
        # Ensure that the exception is working.
        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
            _ = cfg_modifier(cfg)


class PartitionSpecModifierTest(test_utils.TestCase):
    def test_partition_spec_override(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                },
            )
            .instantiate()
        )
        cfg = cfg_modifier(cfg)
        self.assertEqual(
            str(cfg.model.linear.param_partition_spec), "('model', ('expert', 'fsdp', 'seq'))"
        )
        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                    "model.unknown": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                },
            )
            .instantiate()
        )
        # Ensure that the exception is working.
        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
            _ = cfg_modifier(cfg)

        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {
                        "param_partition_spec": ("model", ("expert", "fsdp", "seq")),
                        "unknown_partition_spec": ("model", ("expert", "fsdp", "seq")),
                    },
                },
            )
            .instantiate()
        )
        with self.assertRaisesRegex(ValueError, "unknown_partition_spec is not found in.*"):
            _ = cfg_modifier(cfg)


class MeshShapeModifierTest(test_utils.TestCase):
    def test_mesh_shape_update(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())