
TRN2 Meshes and Configurations #916

Open

apoorvtintin wants to merge 1 commit into base: main from mainline-upstream-boilerplate
Conversation

apoorvtintin (Contributor)

This PR adds meshes for TRN2/1 for the Fuji models, along with a transformer layer configuration favorable to Neuron.

Neuron supports the stacked transformer and GroupedQKVLinear instead of FusedGroupedQKVLinear for Grouped Query Attention (GQA).

This is a newer version of PR #885; it resolves all comments and requested changes mentioned in the linked PR.

@apoorvtintin apoorvtintin requested review from ruomingp, markblee and a team as code owners January 10, 2025 00:48
@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 6b404f6 to 3f7c840 on January 10, 2025 00:53
@apoorvtintin apoorvtintin (Contributor, Author)

Added a ModelConfigModifier that overrides the class for a module, allowing different model configurations based on model size and platform.
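For illustration only, usage on Neuron might look roughly like the sketch below; the dotted module path, the field name, and the default_config pattern are assumptions here rather than the final API:

# Hypothetical sketch: swap the fused QKV projection for the unfused variant,
# since FusedGroupedQKVLinear is not favorable on Neuron.
# The module path and field name below are illustrative only.
from axlearn.common.attention import GroupedQKVLinear

model_cfg_modifier = ModelConfigModifier.default_config().set(
    model_cfg_modifications={
        "model.decoder.transformer.layer.self_attention.attention.input_linear": (
            GroupedQKVLinear.default_config()
        ),
    }
)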

@kelvin-zou kelvin-zou (Contributor) left a comment


Thank you for making this change; overall it looks good. A few nit comments.

continue
# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
Contributor


Can you try to extract a common util function named something like
def replace_module_recursive(target_modules: str, config_key: str, target_config)
and apply it to both here and RematSpecModifier?
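A minimal sketch of such a helper, assuming child configs are reachable by attribute access on the parent config; the name and return shape below are illustrative, not the code that was merged:

def _find_target_module(module_name: str, cfg):
    """Walk a dotted path (e.g. "model.decoder.transformer.layer") and return
    (found_module, parent_module, key_in_parent)."""
    parent, current, key = None, cfg, None
    for key in module_name.split("."):
        parent, current = current, getattr(current, key)
    return current, parent, key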

@apoorvtintin apoorvtintin (Contributor, Author) Jan 10, 2025


I extracted a helper function; let me know if this looks good.

axlearn/common/trainer_config_modifier_test.py (outdated; resolved)
@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 708fc5e to d481132 on January 10, 2025 07:38
@apoorvtintin apoorvtintin (Contributor, Author) commented Jan 10, 2025

Added a ParameterPartitionSpecModifier to shard embeddings in a vocab-parallel manner, as described in Megatron-LM.
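A rough sketch of what vocab-parallel sharding of the token embedding could look like; the module path, the attribute and field names, and the mesh axis names are assumptions for illustration:

# Hypothetical sketch: shard the embedding weight so that the vocab dimension is
# split across the "model" mesh axis (vocab-parallel, as in Megatron-LM).
partition_spec_modifier = ParameterPartitionSpecModifier.default_config().set(
    partition_specs={
        "model.decoder.emb.token_emb": {
            "param_partition_spec": ("model", None),
        },
    }
)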

@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 5be50d7 to 9b10041 on January 10, 2025 08:10
axlearn/common/trainer_config_modifier.py (4 outdated review threads; resolved)

found_module, parent_module, key_in_parent = find_target_module(module_name, cfg)

# Copy configurations from the config being replaced on a best effort basis
Contributor


Wait, this behavior is not explained in the class comments. So we are not replacing but merging the configs? Maybe we should support a merge function instead?

apoorvtintin (Contributor, Author)


Yeah, the goal is to change the config to a similar module, which means most of the configuration can be reused from before; essentially replacing the module but merging the config. Let me extract a merge function.

apoorvtintin (Contributor, Author)


Abstracted out a merge function; let me know if more changes are needed.

@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch from 9b10041 to 0f0a530 on January 12, 2025 07:06
@apoorvtintin apoorvtintin (Contributor, Author)

@ruomingp Thank you for the review. I have addressed all your comments; please let me know if more changes are needed.

@apoorvtintin apoorvtintin requested a review from ruomingp January 12, 2025 07:08
axlearn/common/trainer_config_modifier.py (3 review threads; resolved)
self._model_cfg_modifications = cfg.model_cfg_modifications

def _merge_configs(self, target_cfg: ConfigBase, found_module: ConfigBase) -> ConfigBase:
    """Merge configurations from the config being replaced on a best effort basis.
Contributor


This comment is vague about how merging actually works. Looks like the rule is:

  • Use target_cfg.klass
  • Use the field value from found_module if the field exists in both configs (and is not "klass")
  • Otherwise keep the value from target_cfg
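If that reading is correct, a minimal sketch of the rule would be the following; keys() and plain attribute access on the config objects are assumed here for illustration:

def _merge_configs(target_cfg, found_module):
    # Keep target_cfg.klass; for every other field present in both configs,
    # carry over the value from found_module; otherwise keep target_cfg's value.
    for key in target_cfg.keys():
        if key == "klass":
            continue
        if hasattr(found_module, key):
            setattr(target_cfg, key, getattr(found_module, key))
    return target_cfg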

apoorvtintin (Contributor, Author)


Added a more verbose comment.

axlearn/common/trainer_config_modifier.py (outdated; resolved)
Comment on lines 239 to 244
for module_name, model_cfg in self._model_cfg_modifications.items():
    found_module = _find_target_module(module_name, cfg)
Contributor


In utils.py we have get_recursively and set_recursively for Nested[...]. I wonder if it would be useful to add corresponding methods to ConfigBase. Then we could do something like:

Suggested change
- for module_name, model_cfg in self._model_cfg_modifications.items():
-     found_module = _find_target_module(module_name, cfg)
+ for cfg_path, cfg_modification in self._model_cfg_modifications.items():
+     child_cfg = cfg.get_recursively(cfg_path)
+     child_cfg = cfg_modification(child_cfg, path=cfg_path)
+     cfg.set_recursively(cfg_path, value=child_cfg)

@apoorvtintin apoorvtintin force-pushed the mainline-upstream-boilerplate branch from c23e3b2 to 94bfff6 on January 15, 2025 00:10
@apoorvtintin apoorvtintin (Contributor, Author) commented Jan 15, 2025

Added a more flexible PartitionSpecModifier that can modify multiple partition_spec attributes in a single module config.
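For illustration only (the module path, attribute names, and axis names are assumptions), updating several partition_spec attributes of a single module could look like:

# Hypothetical sketch: modify two partition_spec attributes of the same module
# in one PartitionSpecModifier entry.
partition_spec_modifier = PartitionSpecModifier.default_config().set(
    partition_specs={
        "model.decoder.emb.token_emb": {
            "param_partition_spec": ("model", None),
            "output_partition_spec": ("data", None, "model"),
        },
    }
)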
