Skip to content

Commit

Permalink
Add get_recursive and set_recursive to ConfigBase.
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvtintin committed Jan 27, 2025
1 parent c9f9ab1 commit 2225f87
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 65 deletions.
62 changes: 61 additions & 1 deletion axlearn/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Config(ConfigBase):
from collections import defaultdict
from collections.abc import Collection, Iterable
from functools import cache
from typing import Any, Callable, Generic, Optional, TypeVar, Union
from typing import Any, Callable, Generic, NamedTuple, Optional, Sequence, TypeVar, Union

# attr provides similar features as Python dataclass. Unlike
# dataclass, however, it provides a richer set of features to regulate
Expand Down Expand Up @@ -394,6 +394,66 @@ def set(self, **kwargs):
setattr(self, k, v)
return self

class TraverseResult(NamedTuple):
"""Result of a recurisve traverse in a nested ConfigBase."""

# The parent that contains the reulting key.
parent: _ConfigBase
# The key string.
key: str

def recursive_traverse(self, key_path: Sequence[str]) -> tuple[Any, str]:
"""Recursively traverse the config to find the target key.
Args:
key_path: A sequence of keys for indexing.
Raises:
ValueError: A key in key_path is not found.
Returns:
A tuple containing the parent object and the target key.
"""
target_key = key_path[0]
if target_key not in self:
raise ValueError(f"{target_key} is not found in {self}.")

if len(key_path) == 1:
# Completed traversal
return self.TraverseResult(parent=self, key=key_path[0])

# Continue searching recursively
value = getattr(self, target_key)
return value.recursive_traverse(key_path[1:])

def get_recursively(self, key_path: Sequence[str]) -> Any:
"""Recursively find the target key in the config and return its value.
Args:
key_path: A sequence of keys for indexing to get the target value.
Raises:
ValueError: A key in key_path is not found.
Returns:
value at the key_path.
"""
traverse_result = self.recursive_traverse(key_path)
return getattr(traverse_result.parent, traverse_result.key)

def set_recursively(self, key_path: Sequence[str], new_value: Any):
"""Recursively find the target key in the config and set its value.
Args:
key_path: A sequence of keys for indexing to set the target value.
new_value: New value to replace the target value.
Raises:
ValueError: A key in key_path is not found.
"""
traverse_result = self.recursive_traverse(key_path)
setattr(traverse_result.parent, traverse_result.key, new_value)

def clone(self, **kwargs):
"""Returns a clone of the original config with the optional keyword overrides.
Expand Down
54 changes: 54 additions & 0 deletions axlearn/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,60 @@ def set(self, **kwargs):
self.assertEqual(123, cfg_clone.a)
self.assertEqual("default", cfg_clone.b)

def test_get_recursively(self):
@config_class
class NestedConfig(ConfigBase):
"""A dummy config."""

value: int = 0

@config_class
class TestConfig(ConfigBase):
"""Another dummy config that has a nested config."""

nested: NestedConfig = NestedConfig()
value: int = 1

cfg = TestConfig()

# Test getting nested value.
self.assertEqual(cfg.get_recursively(["nested", "value"]), 0)

# Test getting top-level value.
self.assertEqual(cfg.get_recursively(["value"]), 1)

# Test getting non-existent value.
with self.assertRaises(ValueError):
cfg.get_recursively(["non_existent"])

def test_set_recursively(self):
@config_class
class NestedConfig(ConfigBase):
"""A dummy config."""

value: int = 0

@config_class
class TestConfig(ConfigBase):
"""Another dummy config that has a nested config."""

nested: NestedConfig = NestedConfig()
value: int = 1

cfg = TestConfig()

# Test setting nested value.
cfg.set_recursively(["nested", "value"], 10)
self.assertEqual(cfg.nested.value, 10)

# Test setting top-level value.
cfg.set_recursively(["value"], 5)
self.assertEqual(cfg.value, 5)

# Test setting non-existent value.
with self.assertRaises(ValueError):
cfg.set_recursively(["non_existent"], 20)


if __name__ == "__main__":
absltest.main()
71 changes: 8 additions & 63 deletions axlearn/common/trainer_config_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

"""Defines trainer config modifiers, which will be used in model definitions."""

from typing import Dict, NamedTuple, Sequence, Union
from typing import Dict, Sequence, Union

from axlearn.common import config
from axlearn.common.base_layer import RematSpec
Expand All @@ -21,52 +21,6 @@
from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec


class _FoundModule(NamedTuple):
"""Module found in recursive search of a module name in a nested configudable."""

# The module found
module: Configurable.Config
# The parent of the module found
parent_module: Configurable.Config
# Key of the found module in parent
key_in_parent: str


def _find_target_module(module_name: str, cfg: SpmdTrainer.Config) -> _FoundModule:
"""Recursively search for the target module matching module_name in provided cfg.
Args:
module_name: Name of the target module
cfg: The trainer config to be searched for module_name
Raises:
ValueError: The module_name is not found.
Returns:
A Tuple(curr_module, key_in_parent, parent_module)
curr_module: Module found
parent_module: The parent module
key_in_parent: Key in parent for the found module
"""

# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
curr_module = cfg
key_in_parent = None
parent_module = None

for target_module_key in target_modules:
if not hasattr(curr_module, target_module_key):
raise ValueError(f"{target_module_key} is not found in {curr_module}.")
parent_module = curr_module
key_in_parent = target_module_key
curr_module = getattr(curr_module, target_module_key)
return _FoundModule(
module=curr_module, parent_module=parent_module, key_in_parent=key_in_parent
)


class GradientAccumulationModifier(ConfigModifier):
"""Accumulate gradients for grad_acc_steps steps."""

Expand Down Expand Up @@ -147,11 +101,8 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""

for module_name, remat_spec in self._remat_policies.items():
found_module = _find_target_module(module_name, cfg)
# Here we assume all modules have remat_spec attribute.
if not hasattr(found_module.module, "remat_spec"):
raise ValueError(f"{found_module.module} does not have remat_spec attribute")
found_module.module.remat_spec = remat_spec
cfg.set_recursively(module_name.split(".") + ["remat_spec"], remat_spec)

return cfg


Expand Down Expand Up @@ -228,8 +179,8 @@ def _merge_configs(
for key in target_cfg.keys():
if key == "klass":
continue
elif hasattr(found_module.module, key) and hasattr(target_cfg, key):
setattr(target_cfg, key, getattr(found_module.module, key))
elif hasattr(found_module, key) and hasattr(target_cfg, key):
setattr(target_cfg, key, getattr(found_module, key))
return target_cfg

def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
Expand All @@ -245,10 +196,9 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
The modified trainer config.
"""

found_module = _find_target_module(self._target_config, cfg)
found_module = cfg.get_recursively(self._target_config.split("."))
self._modification = self._merge_configs(self._modification, found_module)
# Replace in the parent config
setattr(found_module.parent_module, found_module.key_in_parent, self._modification)
cfg.set_recursively(self._target_config.split("."), self._modification)
return cfg


Expand Down Expand Up @@ -285,13 +235,8 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
The modified trainer config.
"""
for module_name, partition_spec_dict in self._attribute_dicts.items():
found_module = _find_target_module(module_name, cfg)
for partition_spec_name, partition_spec in partition_spec_dict.items():
if not hasattr(found_module.module, partition_spec_name):
raise ValueError(
f"{found_module.module} does not have {partition_spec_name} attribute"
)
setattr(found_module.module, partition_spec_name, partition_spec)
cfg.set_recursively(module_name.split(".") + [partition_spec_name], partition_spec)

return cfg

Expand Down
2 changes: 1 addition & 1 deletion axlearn/common/trainer_config_modifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_partition_spec_override(self):
)
.instantiate()
)
with self.assertRaisesRegex(ValueError, ".*does not have unknown_partition_spec attribute"):
with self.assertRaisesRegex(ValueError, "unknown_partition_spec is not found in.*"):
_ = cfg_modifier(cfg)


Expand Down

0 comments on commit 2225f87

Please sign in to comment.