From 2b21dfd69fe8d8994dd3509eae9dbd1265469409 Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Tue, 14 Jan 2025 14:52:40 +0000
Subject: [PATCH] Enable quantization of parametrized layers

---
 src/brevitas/graph/base.py            |  48 ++++++++++++-
 src/brevitas/graph/quantize_impl.py   |   8 ++-
 tests/brevitas/graph/test_quantize.py | 100 ++++++++++++++++++++++++++
 3 files changed, 152 insertions(+), 4 deletions(-)

diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py
index d1631f34e..3931640a7 100644
--- a/src/brevitas/graph/base.py
+++ b/src/brevitas/graph/base.py
@@ -109,6 +109,34 @@ def apply(self, model: Module, *model_args, **model_kwargs):
         return model
 
 
+def _remove_parametrization_entries_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
+    # Keys for values related to parametrizations
+    keys_to_remove = []
+    # Keys/values corresponding to the original tensors before parametrizations
+    keys_values_to_add = []
+    # Iterate over the state_dict, identifying the keys related to parametrizations
+    for key, value in state_dict.items():
+        split_key = key.split(".")
+        # Keys for values before parametrizations have the format "prefix.parametrizations.tensor_name.original"
+        if (len(split_key) >= 3 and split_key[-3] == "parametrizations" and
+                split_key[-1] == "original"):
+            tensor_name = split_key[-2]
+            # The name of the new dictionary entry is "prefix.tensor_name"
+            keys_values_to_add.append((".".join(split_key[:-3] + [tensor_name]), value))
+        # Keys corresponding to the parametrizations attached to the model need to be removed
+        # to make sure the dictionary can be loaded with no missing/unused keys
+        # NOTE: For safety, an additional check could be added, as this logic would not work if a
+        # model without parametrizations has any key containing "parametrizations"
+        if "parametrizations" in split_key:
+            keys_to_remove.append(key)
+    # Apply changes in-place to the state_dict
+    for key in keys_to_remove:
+        del state_dict[key]
+    for key, value in keys_values_to_add:
+        state_dict[key] = value
+    return state_dict
+
+
 class ModuleToModule(GraphTransform, ABC):
 
     def __init__(self, new_module_class, **kwargs):
@@ -159,7 +187,25 @@ def _init_new_module(self, old_module: Module, name=None):
     def _replace_old_module(self, model, old_module, new_module, load_state_dict=True):
         replace_module(model, old_module, new_module)
         if load_state_dict:
-            new_module.load_state_dict(old_module.state_dict())
+            old_module_state_dict = old_module.state_dict()
+            # If parametrizations are present in old_module, the state_dict needs
+            # to be processed beforehand
+            if parametrize.is_parametrized(old_module):
+                old_module_state_dict = _remove_parametrization_entries_state_dict(
+                    old_module_state_dict)
+            # strict can be set to True, since potential parametrizations were
+            # accounted for
+            new_module.load_state_dict(old_module_state_dict)
+            # If the old module is parametrized, its parametrizations need to be transferred
+            # to the new module. transfer_parametrizations_and_params is not used, as it can
+            # result in parameter ties being broken
+            # NOTE: unsafe is set to True for efficiency, as the checks should have been done
+            # when first registering the parametrization to old_module
+            if parametrize.is_parametrized(old_module):
+                for tensor_name in old_module.parametrizations:
+                    for param_func in old_module.parametrizations[tensor_name]:
+                        parametrize.register_parametrization(
+                            new_module, tensor_name, param_func, unsafe=True)
 
 
 class InsertModuleCallAfter(GraphTransform):
diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py
index 535f9a8f9..a4d348ab5 100644
--- a/src/brevitas/graph/quantize_impl.py
+++ b/src/brevitas/graph/quantize_impl.py
@@ -6,6 +6,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.utils.parametrize as parametrize
 
 import brevitas
 from brevitas.graph.base import InsertModuleCallAfter
@@ -511,7 +512,7 @@ def find_module(
     Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its
     Linear submodules.
     """
-    if _module_class_name(type(model)) in layer_map.keys():
+    if _module_class_name(parametrize.type_before_parametrizations(model)) in layer_map.keys():
         module_to_replace.append(model)
     else:
         for name, module in model.named_children():
@@ -532,8 +533,9 @@ def layerwise_layer_handler(
     find_module(model, layer_map, module_to_replace, name_blacklist)
     rewriters = []
     for module in module_to_replace:
-        if layer_map[_module_class_name(type(module))] is not None:
-            quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type(module))]
+        if layer_map[_module_class_name(
+                parametrize.type_before_parametrizations(module))] is not None:
+            quant_module_class, quant_module_kwargs = layer_map[_module_class_name(parametrize.type_before_parametrizations(module))]
             rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs)
             rewriters.append(rewriter)
     for rewriter in rewriters:
diff --git a/tests/brevitas/graph/test_quantize.py b/tests/brevitas/graph/test_quantize.py
index 45167b673..ec15da6a2 100644
--- a/tests/brevitas/graph/test_quantize.py
+++ b/tests/brevitas/graph/test_quantize.py
@@ -1,7 +1,14 @@
+import copy
+
 import pytest_cases
+import torch
 import torch.nn as nn
+import torch.nn.utils.parametrize as parametrize
 
+from brevitas.graph.base import _remove_parametrization_entries_state_dict
 from brevitas.graph.quantize import layerwise_quantize
+from brevitas.utils.python_utils import recurse_getattr
+from brevitas.utils.rotation_utils import RotationWeightParametrization
 
 
 @pytest_cases.parametrize(
@@ -42,3 +49,96 @@ def test_layerwise_quantize_blacklist(kwargs):
             assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}"
             checked = True
     assert checked, f"Layer named {key} not found. Layer names are: {found_names}"
+
+
+@pytest_cases.parametrize(
+    'kwargs',
+    [
+        {
+            'model': nn.Sequential(nn.Linear(2, 3)),
+            'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)),
+            'rot_func': lambda tensor,
+            rot_mat,
+            K: torch.matmul(tensor, rot_mat),
+            'key': '0',
+            'expected': "<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>"},])
+def test_layerwise_quantize_parametrized_modules(kwargs):
+    key = kwargs['key']
+    exp = kwargs['expected']
+    rot_mat = kwargs['rot_mat']
+    rot_func = kwargs['rot_func']
+    del kwargs['key']
+    del kwargs['expected']
+    del kwargs['rot_mat']
+    del kwargs['rot_func']
+
+    model = kwargs["model"]
+    module = recurse_getattr(model, key)
+    # Register rotation parametrization to module
+    parametrize.register_parametrization(
+        module=module,
+        tensor_name="weight",
+        parametrization=RotationWeightParametrization(
+            rot_mat=nn.Parameter(rot_mat),
+            rot_func=rot_func,
+            axis=1,
+            K=None,
+        ))
+    qmodel = layerwise_quantize(**kwargs)
+    checked = False
+    found_names = []
+    for n, m in qmodel.named_modules():
+        found_names.append(n)
+        if n == key:
+            mt = str(type(m))
+            assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}"
+            checked = True
+    assert checked, f"Layer named {key} not found. Layer names are: {found_names}"
+
+
+@pytest_cases.parametrize(
+    'kwargs',
+    [{
+        'model': nn.Sequential(nn.Linear(2, 3)),
+        'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)),
+        'rot_func': lambda tensor,
+        rot_mat,
+        K: torch.matmul(tensor, rot_mat),
+        'key': '0',
+        'expected_state_dict_keys': ['0.weight', '0.bias'],}])
+def test_remove_parametrization_entries_state_dict(kwargs):
+    key = kwargs['key']
+    rot_mat = kwargs['rot_mat']
+    rot_func = kwargs['rot_func']
+    expected_state_dict_keys = kwargs['expected_state_dict_keys']
+    del kwargs['key']
+    del kwargs['rot_mat']
+    del kwargs['rot_func']
+    del kwargs['expected_state_dict_keys']
+
+    model = kwargs["model"]
+    module = recurse_getattr(model, key)
+    old_state_dict = copy.deepcopy(model.state_dict())
+    # Register rotation parametrization to module
+    parametrize.register_parametrization(
+        module=module,
+        tensor_name="weight",
+        parametrization=RotationWeightParametrization(
+            rot_mat=nn.Parameter(rot_mat),
+            rot_func=rot_func,
+            axis=1,
+            K=None,
+        ))
+    # Retrieve the state dict after registering the parametrization
+    state_dict = model.state_dict()
+    # Remove the parametrization entries from the state dict
+    state_dict = _remove_parametrization_entries_state_dict(state_dict)
+    # Verify that all the keys in expected_state_dict_keys
+    # are present in state_dict
+    assert len(set(expected_state_dict_keys) - set(state_dict.keys())) == 0
+    # Verify that the remaining entries match the pre-parametrization values
+    for key, value in state_dict.items():
+        # Verify that the key is among the expected ones
+        assert key in expected_state_dict_keys, f"Unexpected key {key} in state_dict"
+        # Compare tensor values exactly
+        assert torch.allclose(value, old_state_dict[key], rtol=0.0, atol=0.0), f"Value of tensor {key} does not match that in the original state_dict"
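
Usage note: a minimal sketch of the workflow this patch enables, assembled
from the tests above. The imports and the RotationWeightParametrization
arguments are taken from the patch; the surrounding script is illustrative,
not part of the change itself.

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    from brevitas.graph.quantize import layerwise_quantize
    from brevitas.utils.rotation_utils import RotationWeightParametrization

    model = nn.Sequential(nn.Linear(2, 3))
    rot_mat = torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.))

    # Attach a rotation parametrization to the Linear weight. From here on,
    # type(model[0]) reports a dynamically created Parametrized* subclass,
    # which is why the layer map lookup in quantize_impl.py must switch to
    # parametrize.type_before_parametrizations.
    parametrize.register_parametrization(
        module=model[0],
        tensor_name="weight",
        parametrization=RotationWeightParametrization(
            rot_mat=nn.Parameter(rot_mat),
            rot_func=lambda tensor, rot_mat, K: torch.matmul(tensor, rot_mat),
            axis=1,
            K=None))

    # With this patch, the parametrized layer is matched in the layer map,
    # replaced by its quantized counterpart, and the parametrization is
    # re-registered on the new module.
    qmodel = layerwise_quantize(model)
    assert parametrize.is_parametrized(qmodel[0])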
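State-dict handling, for reference: a sketch of what
_remove_parametrization_entries_state_dict does to the keys of a freshly
parametrized model, continuing the names from the sketch above. The
parametrization key names in the comments follow PyTorch's parametrize
naming scheme and are illustrative; the resulting key set is the one
asserted in test_remove_parametrization_entries_state_dict.

    import copy

    from brevitas.graph.base import _remove_parametrization_entries_state_dict

    model2 = nn.Sequential(nn.Linear(2, 3))
    ref_sd = copy.deepcopy(model2.state_dict())
    parametrize.register_parametrization(
        module=model2[0],
        tensor_name="weight",
        parametrization=RotationWeightParametrization(
            rot_mat=nn.Parameter(rot_mat),
            rot_func=lambda tensor, rot_mat, K: torch.matmul(tensor, rot_mat),
            axis=1,
            K=None))

    sd = model2.state_dict()
    # After parametrization, the keys look schematically like:
    #   '0.bias'
    #   '0.parametrizations.weight.original'   (tensor before parametrization)
    #   '0.parametrizations.weight.0.rot_mat'  (parametrization's own parameter)
    sd = _remove_parametrization_entries_state_dict(sd)
    # '0.parametrizations.weight.original' is renamed back to '0.weight' and
    # every other parametrization entry is dropped, so a plain nn.Linear (or
    # its quantized replacement) can load the result with strict=True.
    assert set(sd.keys()) == {'0.weight', '0.bias'}
    assert torch.equal(sd['0.weight'], ref_sd['0.weight'])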