From 09f13710fd9733c12c2cb6844715115f55086bf7 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 17 Dec 2024 14:27:39 +0000
Subject: [PATCH] Feat (brevitas_examples/llm): correct scale init with CPU offloading (#1124)

---
 src/brevitas/utils/python_utils.py | 10 ++++++++++
 src/brevitas_examples/llm/main.py  | 22 ++++++++++++++++++++++
 2 files changed, 32 insertions(+)

diff --git a/src/brevitas/utils/python_utils.py b/src/brevitas/utils/python_utils.py
index fa69499c7..ae8845b48 100644
--- a/src/brevitas/utils/python_utils.py
+++ b/src/brevitas/utils/python_utils.py
@@ -54,3 +54,13 @@ def _getattr(obj, attr):
         return getattr(obj, attr)
 
     return functools.reduce(_getattr, [obj] + attr.split("."))
+
+
+def hooked_on_a_function(function, prefunction):
+
+    @functools.wraps(function)
+    def run(*args, **kwargs):
+        prefunction()
+        return function(*args, **kwargs)
+
+    return run
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index e4f6dda41..b9c2c1c4d 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -3,6 +3,7 @@
 
 import argparse
 from copy import deepcopy
+import functools
 import sys
 from warnings import warn
 
@@ -20,8 +21,10 @@
 from brevitas.graph.equalize import LayerwiseActivationRotation
 from brevitas.graph.quantize import layerwise_quantize
 from brevitas.graph.utils import get_module
+from brevitas.utils.python_utils import hooked_on_a_function
 from brevitas_examples.common.accelerate_utils.accelerate import offload_model
 from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
+from brevitas_examples.common.accelerate_utils.accelerate import update_internal_dict
 from brevitas_examples.common.generative.quantize import generate_quant_maps
 from brevitas_examples.common.generative.quantize import generate_quantizers
 from brevitas_examples.common.parse_utils import quant_format_validator
@@ -378,9 +381,28 @@ def main(args):
 
     model = offload_model(model)
 
+    dict_hooks = dict()
+
+    # When offloading to CPU + GPU, the CPU scale factors must be updated
+    # before the weights are moved back to the meta device.
+    # Otherwise the new value is lost while the internal flag "init_done" is already True,
+    # so the wrong scale would be used. To do this, we attach a "hook" that runs before
+    # post_forward and updates the offload dict with the initialized scales.
+    for m in model.modules():
+        if hasattr(m, '_hf_hook'):
+            if m._hf_hook.weights_map is not None:
+                # We store the original function so it can be restored later
+                dict_hooks[m] = m._hf_hook.post_forward
+                new_funct = functools.partial(update_internal_dict, m)
+                m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct)
+
     with torch.no_grad():
         model(**calibration_loader[0])
 
+    # We restore the original behaviour of post_forward.
+    for k, v in dict_hooks.items():
+        k._hf_hook.post_forward = v
+
     if args.act_calibration:
         print("Apply act calibration...")
         apply_calibration(model, calibration_loader)
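
Note (not part of the patch): a minimal standalone sketch of the wrapping pattern used above, assuming only the Python standard library. hooked_on_a_function is copied from the patch; post_forward and sync_scales are hypothetical stand-ins for accelerate's AlignDevicesHook.post_forward and the functools.partial(update_internal_dict, m) callback used in main.py.

import functools


def hooked_on_a_function(function, prefunction):
    # Returns a wrapper that calls `prefunction` first and then the wrapped
    # function, preserving the original name/docstring via functools.wraps.

    @functools.wraps(function)
    def run(*args, **kwargs):
        prefunction()
        return function(*args, **kwargs)

    return run


def post_forward(output):
    # Hypothetical stand-in for an accelerate hook's post_forward.
    return output


def sync_scales():
    # Hypothetical stand-in for the scale-sync callback (update_internal_dict in the patch).
    print("copying initialized scales back to the offload dict")


post_forward = hooked_on_a_function(post_forward, sync_scales)
assert post_forward("out") == "out"  # sync_scales runs before the original post_forward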