TRN2 Meshes and Configurations #916
Conversation
Force-pushed from 6b404f6 to 3f7c840.
Added a ModelConfigModifier that overrides the class for a module, allowing different model configurations based on model size and platform.
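A rough sketch of how such a modifier might be configured. The `model_cfg_modifications` field appears in the diffs below, but the module path and the choice of replacement class here are illustrative assumptions, not the final API:

```python
# Hypothetical usage sketch of ModelConfigModifier (the module path and the
# replacement class are assumptions based on this PR's description, not a
# verbatim excerpt). ModelConfigModifier itself is the class added by this PR.
from axlearn.common.attention import GroupedQKVLinear

modifier_cfg = ModelConfigModifier.default_config().set(
    model_cfg_modifications={
        # Swap the fused QKV projection for the unfused variant, e.g. because
        # the target platform (Neuron) supports GroupedQKVLinear.
        "model.decoder.transformer.layer.self_attention.attention.input_linear": (
            GroupedQKVLinear.default_config()
        ),
    }
)
```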
Thank you for making this change; overall it looks good. A few nit comments.
```python
continue
# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
```
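For reference, the dotted path from the comment splits into per-level attribute names:

```python
>>> "model.decoder.transformer.layer".split(".")
['model', 'decoder', 'transformer', 'layer']
```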
Can you try to extract a common util function named something like `def replace_module_recursive(target_modules: str, config_key: str, target_config)` and apply it to both here and RematSpecModifier?
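A minimal sketch of what such a shared helper could look like, assuming attribute-style access on nested configs. The signature follows the suggestion above plus a `cfg` root argument; this is not the code that landed:

```python
def replace_module_recursive(target_modules: str, config_key: str, target_config, cfg):
    """Hypothetical shared helper: walks a dotted path (x.y.z format) through
    `cfg` and sets `config_key` on the config found at the end of the path."""
    module = cfg
    for name in target_modules.split("."):
        if not hasattr(module, name):
            raise ValueError(f"Module {name} not found under {type(module)}.")
        module = getattr(module, name)
    setattr(module, config_key, target_config)
    return cfg
```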
I extracted a helper function; let me know if this looks good.
Force-pushed from 708fc5e to d481132.
Added
Force-pushed from 5be50d7 to 9b10041.
```python
found_module, parent_module, key_in_parent = find_target_module(module_name, cfg)

# Copy configurations from the config being replaced on a best effort basis
```
Wait, this behavior is not explained in the class comments. So we are not replacing but merging the configs? Maybe we should support a merge function instead?
Yeah, the goal is to change the config to a similar module, so most of the configuration can be reused from before: essentially replacing the module but merging the config. Let me extract a merge function.
Abstracted out a merge function; let me know if more changes are needed for this.
Force-pushed from 9b10041 to 0f0a530.
@ruomingp Thank you for the review. I have addressed all your comments; please let me know if more changes are needed.
```python
self._model_cfg_modifications = cfg.model_cfg_modifications

def _merge_configs(self, target_cfg: ConfigBase, found_module: ConfigBase) -> ConfigBase:
    """Merge configurations from the config being replaced on a best effort basis."""
```
This comment is vague about how merging actually works. Looks like the rule is:
- Use `target_cfg.klass`.
- Use the field value from `found_module` if the field exists in both configs (and is not `"klass"`).
- Otherwise keep the value from `target_cfg`.
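If that reading is right, the rule can be stated as a short sketch (hypothetical, paraphrasing the three bullets above; `keys()` and attribute access on configs are assumed):

```python
def _merge_configs(target_cfg, found_module):
    # Keep target_cfg.klass: swapping the class is the point of the modifier.
    for key in target_cfg.keys():
        if key == "klass":
            continue
        # Carry over the old value when the field exists in both configs;
        # otherwise the default from target_cfg is kept.
        if hasattr(found_module, key):
            setattr(target_cfg, key, getattr(found_module, key))
    return target_cfg
```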
Added a more verbose comment
```python
for module_name, model_cfg in self._model_cfg_modifications.items():
    found_module = _find_target_module(module_name, cfg)
```
In utils.py we have `get_recursively` and `set_recursively` for `Nested[...]`. I wonder if it would be useful to add corresponding methods to ConfigBase. Then we could do something like:
```diff
-for module_name, model_cfg in self._model_cfg_modifications.items():
-    found_module = _find_target_module(module_name, cfg)
+for cfg_path, cfg_modification in self._model_cfg_modifications.items():
+    child_cfg = cfg.get_recursively(cfg_path)
+    child_cfg = cfg_modification(child_cfg, path=cfg_path)
+    cfg.set_recursively(cfg_path, value=child_cfg)
```
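If ConfigBase grew such methods, they could mirror the dotted-path walk used elsewhere in this PR. A sketch under that assumption (the utils.py helpers exist for `Nested[...]`; these ConfigBase counterparts are hypothetical):

```python
# Hypothetical ConfigBase methods mirroring utils.get_recursively/set_recursively.
def get_recursively(self, path: str):
    """Returns the descendant config at a dotted path, e.g. "decoder.transformer"."""
    child = self
    for name in path.split("."):
        child = getattr(child, name)
    return child

def set_recursively(self, path: str, *, value):
    """Replaces the descendant config at a dotted path with `value`."""
    parent_path, _, leaf = path.rpartition(".")
    parent = self.get_recursively(parent_path) if parent_path else self
    setattr(parent, leaf, value)
```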
Force-pushed from 0f0a530 to c23e3b2.
Force-pushed from c23e3b2 to 94bfff6.
Added a more flexible
This PR adds meshes for TRN2/TRN1 for the Fuji models, plus a transformer layer configuration favorable to Neuron: Neuron supports the stacked transformer and GroupedQKVLinear instead of FusedGroupedQKVLinear for Grouped Query Attention (GQA).
This is a newer version of PR #885; it resolves all comments and requested changes from the linked PR.