From 4d042b7762843c7d38fe30f1d5b6791e6527183d Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 9 Dec 2024 16:37:41 +0000
Subject: [PATCH 1/3] Feat (brevitas_examples/llm): correct scale init with
 CPU offloading

---
 src/brevitas/utils/python_utils.py | 10 ++++++++++
 src/brevitas_examples/llm/main.py  | 18 ++++++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/src/brevitas/utils/python_utils.py b/src/brevitas/utils/python_utils.py
index fa69499c7..ae8845b48 100644
--- a/src/brevitas/utils/python_utils.py
+++ b/src/brevitas/utils/python_utils.py
@@ -54,3 +54,13 @@ def _getattr(obj, attr):
         return getattr(obj, attr)
 
     return functools.reduce(_getattr, [obj] + attr.split("."))
+
+
+def hooked_on_a_function(function, prefunction):
+
+    @functools.wraps(function)
+    def run(*args, **kwargs):
+        prefunction()
+        return function(*args, **kwargs)
+
+    return run
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index e4f6dda41..527487fed 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -3,6 +3,7 @@
 
 import argparse
 from copy import deepcopy
+import functools
 import sys
 from warnings import warn
 
@@ -20,8 +21,10 @@
 from brevitas.graph.equalize import LayerwiseActivationRotation
 from brevitas.graph.quantize import layerwise_quantize
 from brevitas.graph.utils import get_module
+from brevitas.utils.python_utils import hooked_on_a_function
 from brevitas_examples.common.accelerate_utils.accelerate import offload_model
 from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
+from brevitas_examples.common.accelerate_utils.accelerate import update_internal_dict
 from brevitas_examples.common.generative.quantize import generate_quant_maps
 from brevitas_examples.common.generative.quantize import generate_quantizers
 from brevitas_examples.common.parse_utils import quant_format_validator
@@ -378,9 +381,24 @@ def main(args):
 
     model = offload_model(model)
 
+    dict_hooks = dict()
+
+    def update_params_post_init(module):
+        update_internal_dict(module)
+
+    for m in model.modules():
+        if hasattr(m, '_hf_hook'):
+            if m._hf_hook.weights_map is not None:
+                dict_hooks[m] = m._hf_hook.post_forward
+                new_funct = functools.partial(update_params_post_init, m)
+                m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct)
+
     with torch.no_grad():
         model(**calibration_loader[0])
 
+    for k, v in dict_hooks.items():
+        k._hf_hook.post_forward = v
+
     if args.act_calibration:
         print("Apply act calibration...")
         apply_calibration(model, calibration_loader)

From 7b68194a5f4af371468751d49623308a4f785219 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 17 Dec 2024 13:44:33 +0100
Subject: [PATCH 2/3] Update main.py

---
 src/brevitas_examples/llm/main.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 527487fed..63898d6ef 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -383,19 +383,23 @@ def main(args):
 
     dict_hooks = dict()
 
-    def update_params_post_init(module):
-        update_internal_dict(module)
-
+    # When offloading to CPU + GPU, the CPU scale factors must be updated
+    # before we move them back to the meta device.
+    # If we don't, we lose the new value but the internal flag "init_done" is True, thus we will use the wrong scale.
+    # To do this, we attach a "hook" to the post_forward function, called before the post_forward
+    # The function will update the dict with the initialized scales
     for m in model.modules():
         if hasattr(m, '_hf_hook'):
             if m._hf_hook.weights_map is not None:
+                # We store the original function to be restored later
                 dict_hooks[m] = m._hf_hook.post_forward
-                new_funct = functools.partial(update_params_post_init, m)
+                new_funct = functools.partial(update_internal_dict, m)
                 m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct)
 
     with torch.no_grad():
         model(**calibration_loader[0])
 
+    # We restore the original behaviour of the post-forward. 
     for k, v in dict_hooks.items():
         k._hf_hook.post_forward = v
 

From 5190deba5047055d9f3e1e4d5035125f02871e39 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 17 Dec 2024 14:10:38 +0000
Subject: [PATCH 3/3] precommit

---
 src/brevitas_examples/llm/main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 63898d6ef..b9c2c1c4d 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -399,7 +399,7 @@ def main(args):
     with torch.no_grad():
         model(**calibration_loader[0])
 
-    # We restore the original behaviour of the post-forward. 
+    # We restore the original behaviour of the post-forward.
     for k, v in dict_hooks.items():
         k._hf_hook.post_forward = v
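
A note on the mechanism, following the comments added in PATCH 2/3: with accelerate CPU offloading, post_forward moves a module's parameters back to the offloaded storage after each forward, so scale factors initialised during the first calibration forward would be lost even though the quantizer's init_done flag is already True. The patches therefore wrap post_forward so that update_internal_dict pushes the freshly initialised scales into the hook's state first, and they restore the original post_forward once calibration is done. Below is a minimal, self-contained sketch of that wrap-then-restore pattern; FakeHook and sync_scales are hypothetical stand-ins for accelerate's hook object and Brevitas' update_internal_dict, not the real APIs.

import functools


def hooked_on_a_function(function, prefunction):
    # Same helper added to brevitas/utils/python_utils.py in PATCH 1/3:
    # returns a wrapper that calls `prefunction` before delegating to `function`.

    @functools.wraps(function)
    def run(*args, **kwargs):
        prefunction()
        return function(*args, **kwargs)

    return run


class FakeHook:
    # Hypothetical stand-in for the hook object found at `module._hf_hook`.

    def post_forward(self, module, output):
        print("post_forward: weights moved back to the offload device")
        return output


def sync_scales(module_name):
    # Hypothetical stand-in for update_internal_dict(module): copy the freshly
    # initialised scale factors into the hook's offloaded state before
    # post_forward discards the on-device copy.
    print(f"syncing scales for {module_name}")


hook = FakeHook()
original_post_forward = hook.post_forward  # stored so it can be restored, as dict_hooks does
hook.post_forward = hooked_on_a_function(
    hook.post_forward, functools.partial(sync_scales, "model.layers.0.self_attn.q_proj"))

hook.post_forward(None, "output")  # prints the sync message, then runs the original post_forward
hook.post_forward = original_post_forward  # restore the untouched behaviour

Wrapping rather than replacing post_forward keeps accelerate's own device-management logic intact; the only change is that the scale synchronisation runs first, and removing the wrapper afterwards leaves no trace of the calibration pass.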