TRN2 Meshes and Configurations #916
base: main
@@ -70,7 +70,7 @@ class Config(ConfigBase):
from collections import defaultdict
from collections.abc import Collection, Iterable
from functools import cache
-from typing import Any, Callable, Generic, Optional, TypeVar, Union
+from typing import Any, Callable, Generic, NamedTuple, Optional, Sequence, TypeVar, Union

# attr provides similar features as Python dataclass. Unlike
# dataclass, however, it provides a richer set of features to regulate
@@ -394,6 +394,66 @@ def set(self, **kwargs):
            setattr(self, k, v)
        return self

    class TraverseResult(NamedTuple):
        """Result of a recursive traversal in a nested ConfigBase."""

        # The parent config that contains the resulting key.
        parent: _ConfigBase
        # The key string.
        key: str

    def recursive_traverse(self, key_path: Sequence[str]) -> TraverseResult:
        """Recursively traverses the config to find the target key.

        Args:
            key_path: A sequence of keys for indexing.

        Raises:
            ValueError: A key in key_path is not found.

        Returns:
            A TraverseResult containing the parent config and the target key.
        """
        target_key = key_path[0]
        if target_key not in self:
            raise ValueError(f"{target_key} is not found in {self}.")

        if len(key_path) == 1:
            # Traversal is complete.
            return self.TraverseResult(parent=self, key=target_key)

        # Continue searching recursively.
        value = getattr(self, target_key)
        return value.recursive_traverse(key_path[1:])

    def get_recursively(self, key_path: Sequence[str]) -> Any:
        """Recursively finds the target key in the config and returns its value.

        Args:
            key_path: A sequence of keys for indexing to get the target value.

        Raises:
            ValueError: A key in key_path is not found.

        Returns:
            The value at key_path.
        """
        traverse_result = self.recursive_traverse(key_path)
        return getattr(traverse_result.parent, traverse_result.key)

Review comment on key_path: Can path be empty? Maybe it can return …
||||||||||||||
def set_recursively(self, key_path: Sequence[str], new_value: Any): | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please name consistently with axlearn/axlearn/common/utils.py Lines 907 to 910 in a854738
Suggested change
|
||||||||||||||
"""Recursively find the target key in the config and set its value. | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
key_path: A sequence of keys for indexing to set the target value. | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can path be empty? |
||||||||||||||
new_value: New value to replace the target value. | ||||||||||||||
|
||||||||||||||
Raises: | ||||||||||||||
ValueError: A key in key_path is not found. | ||||||||||||||
""" | ||||||||||||||
traverse_result = self.recursive_traverse(key_path) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we do something like:
Suggested change
|
||||||||||||||
setattr(traverse_result.parent, traverse_result.key, new_value) | ||||||||||||||
|
||||||||||||||
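
A minimal usage sketch of the two helpers (hedged: it assumes `cfg` is a fully populated SpmdTrainer.Config with the usual model.decoder.transformer.layer hierarchy; the path and new value are illustrative, not from this PR):

    # Sketch, not part of the diff: read and overwrite a deeply nested field.
    path = "model.decoder.transformer.layer.remat_spec".split(".")
    old_value = cfg.get_recursively(path)  # Raises ValueError if any key is missing.
    cfg.set_recursively(path, None)        # Replaces the leaf value in place.
    assert cfg.get_recursively(path) is None
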
    def clone(self, **kwargs):
        """Returns a clone of the original config with the optional keyword overrides.

@@ -10,14 +10,15 @@
    REQUIRED,
    ConfigModifier,
    ConfigOr,
+    Configurable,
    Required,
    config_class,
    maybe_instantiate,
)
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.trainer import SpmdTrainer
-from axlearn.common.utils import HybridMeshShape, MeshShape
+from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec


class GradientAccumulationModifier(ConfigModifier):
@@ -100,18 +101,8 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """

        for module_name, remat_spec in self._remat_policies.items():
-            # Here we assume x.y.z format.
-            # One example would be model.decoder.transformer.layer.
-            target_modules = module_name.split(".")
-            curr_module = cfg
-            for target_module in target_modules:
-                if not hasattr(curr_module, target_module):
-                    raise ValueError(f"{target_module} is not found in {curr_module}.")
-                curr_module = getattr(curr_module, target_module)
-            # Here we assume all modules have remat_spec attribute.
-            if not hasattr(curr_module, "remat_spec"):
-                raise ValueError(f"{curr_module} does not have remat_spec attribute")
-            curr_module.remat_spec = remat_spec
+            cfg.set_recursively(module_name.split(".") + ["remat_spec"], remat_spec)

        return cfg

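The two explicit hasattr checks in the removed loop are subsumed by set_recursively, which raises ValueError from recursive_traverse at whichever level a key is missing. A hedged sketch of the failure mode, with a deliberately misspelled, hypothetical path:

    # Sketch: a typo in the module path now fails inside set_recursively.
    try:
        cfg.set_recursively("model.decodr".split(".") + ["remat_spec"], remat_spec)
    except ValueError as e:
        print(e)  # e.g. "decodr is not found in <the model config>."
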
@@ -146,6 +137,110 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        return cfg

class ModelConfigModifier(ConfigModifier):
    """Update the model config for the trainer config."""

Review comment: Which part of this class is specific to model? It seems to take generic modifications?

    @config_class
    class Config(ConfigModifier.Config):
        """Configure ModelConfigModifier.

        Attributes:
            target_config: Dot-separated path of the config to be replaced
                (e.g. `model.decoder.transformer.layer`).
            modification: The config that replaces the config at target_config.
        """

        target_config: Required[str] = REQUIRED
        modification: Required[Configurable.Config] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._target_config = self.config.target_config
        self._modification = self.config.modification

    def _merge_configs(
        self, target_cfg: Configurable.Config, found_module: Configurable.Config
    ) -> Configurable.Config:
        """Merge configurations from the config being replaced on a best-effort basis.

        Merge rules:
            - `klass` is not changed; the target config keeps its class.
            - If a field exists in both configs, use the value from the config being replaced.
            - Otherwise, keep the value from target_cfg.

        Args:
            target_cfg: Configuration that will replace found_module.
            found_module: Existing configuration whose class will be replaced,
                but whose configuration will be merged with target_cfg.

        Returns:
            The modified config.
        """
        for key in target_cfg.keys():
            if key == "klass":
                continue
            elif hasattr(found_module, key) and hasattr(target_cfg, key):
                setattr(target_cfg, key, getattr(found_module, key))
        return target_cfg

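To make the merge rules concrete, a hedged sketch (the layer configs and the input_dim field are hypothetical stand-ins, not from this PR):

    # Sketch: old_layer_cfg is the config being replaced; new_layer_cfg is the
    # replacement config of a different class with overlapping fields.
    merged = modifier._merge_configs(target_cfg=new_layer_cfg, found_module=old_layer_cfg)
    # Rule 1: merged keeps new_layer_cfg's klass.
    # Rule 2: merged.input_dim == old_layer_cfg.input_dim (field exists in both).
    # Rule 3: fields present only on new_layer_cfg keep their existing values.
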
    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Overwrite the model config of the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.

        Returns:
            The modified trainer config.
        """
        found_module = cfg.get_recursively(self._target_config.split("."))
        self._modification = self._merge_configs(self._modification, found_module)
        cfg.set_recursively(self._target_config.split("."), self._modification)
        return cfg

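A hedged usage sketch; only target_config and modification are the PR's actual fields, while the replacement class and trainer config below are illustrative:

    # Sketch: swap the decoder's transformer stack for a different layer class,
    # carrying shared fields over via _merge_configs.
    model_modifier_cfg = ModelConfigModifier.default_config().set(
        target_config="model.decoder.transformer",
        modification=StackedTransformerLayer.default_config(),  # hypothetical replacement
    )
    trainer_cfg = model_modifier_cfg.instantiate()(trainer_cfg)
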
class PartitionSpecModifier(ConfigModifier):
    """Update the partition spec attribute for the specified modules."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure PartitionSpecModifier.

        Attributes:
            partition_specs: A nested mapping from module path
                (e.g. `model.decoder.transformer.layer`) to another
                mapping of model attribute to PartitionSpec.
        """

        partition_specs: Required[Dict[str, Dict[str, PartitionSpec]]] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._attribute_dicts = self.config.partition_specs

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Update the partition_spec attributes for the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.
            ValueError: The partition_spec attribute is not found.

        Returns:
            The modified trainer config.
        """
        for module_name, partition_spec_dict in self._attribute_dicts.items():
            for partition_spec_name, partition_spec in partition_spec_dict.items():
                cfg.set_recursively(module_name.split(".") + [partition_spec_name], partition_spec)

        return cfg

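A hedged usage sketch; the module path, attribute name, and mesh-axis names are illustrative, not from this PR:

    # Sketch: shard a hypothetical embedding weight over the "model" mesh axis.
    partition_modifier_cfg = PartitionSpecModifier.default_config().set(
        partition_specs={
            "model.decoder.emb.token_emb": {
                "param_partition_spec": PartitionSpec("model", None),
            },
        },
    )
    trainer_cfg = partition_modifier_cfg.instantiate()(trainer_cfg)
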
class ChainConfigModifier(ConfigModifier):
    """Chain multiple config modifiers together."""

Review comment: Does this need to be a public method?