Enable quantization of parametrized layers
pablomlago committed Jan 14, 2025
1 parent 23c5814 commit 2b21dfd
Showing 3 changed files with 152 additions and 4 deletions.
48 changes: 47 additions & 1 deletion src/brevitas/graph/base.py
@@ -109,6 +109,34 @@ def apply(self, model: Module, *model_args, **model_kwargs):
        return model


def _remove_parametrization_entries_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Keys for values related to parametrizations
    keys_to_remove = []
    # Keys/values corresponding to the original tensors before parametrizations
    keys_values_to_add = []
    # Iterate over the state_dict, identifying the keys related to parametrizations
    for key, value in state_dict.items():
        split_key = key.split(".")
        # Keys for values before parametrizations have the format "prefix.parametrizations.tensor_name.original"
        if len(split_key) >= 3 and split_key[-3] == "parametrizations" and split_key[-1] == "original":
            tensor_name = split_key[-2]
            # Name of the dictionary entry is "prefix.tensor_name"
            keys_values_to_add.append((".".join(split_key[:-3] + [tensor_name]), value))
        # Keys corresponding to the parametrizations attached to the model need to be removed
        # to make sure the dictionary can be loaded with no missing/unused keys
        # NOTE: For safety, an additional check could be added, as this logic would not work if a model
        # without parametrizations has any key containing "parametrizations"
        if "parametrizations" in split_key:
            keys_to_remove.append(key)
    # Apply changes in-place to the state_dict
    for key in keys_to_remove:
        del state_dict[key]
    for key, value in keys_values_to_add:
        state_dict[key] = value
    return state_dict
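For context, a minimal sketch of what this helper does to the state_dict of a parametrized module; the Double parametrization and the one-layer model are hypothetical, used only for illustration:

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

from brevitas.graph.base import _remove_parametrization_entries_state_dict


class Double(nn.Module):
    # Hypothetical parametrization: doubles the tensor it wraps
    def forward(self, x):
        return 2 * x


model = nn.Sequential(nn.Linear(2, 3))
parametrize.register_parametrization(model[0], "weight", Double())
# With the parametrization attached, "0.weight" disappears and the original
# tensor is stored under "0.parametrizations.weight.original"
print(sorted(model.state_dict().keys()))  # ['0.bias', '0.parametrizations.weight.original']
# The helper restores the pre-parametrization key layout
clean = _remove_parametrization_entries_state_dict(model.state_dict())
print(sorted(clean.keys()))  # ['0.bias', '0.weight']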


class ModuleToModule(GraphTransform, ABC):

    def __init__(self, new_module_class, **kwargs):
@@ -159,7 +187,25 @@ def _init_new_module(self, old_module: Module, name=None):
    def _replace_old_module(self, model, old_module, new_module, load_state_dict=True):
        replace_module(model, old_module, new_module)
        if load_state_dict:
            new_module.load_state_dict(old_module.state_dict())
            old_module_state_dict = old_module.state_dict()
            # If parametrizations are present in old_module, the state_dict needs
            # to be processed beforehand
            if parametrize.is_parametrized(old_module):
                old_module_state_dict = _remove_parametrization_entries_state_dict(
                    old_module_state_dict)
            # Strict can be set to True, since potential parametrizations were
            # accounted for
            new_module.load_state_dict(old_module_state_dict)
        # If the old module is parametrized, its parametrizations need to be transferred
        # to the new module. The method transfer_parametrizations_and_params is not used,
        # as it can result in parameter ties being broken
        # NOTE: unsafe is set to True for efficiency, as the checks should have been done
        # when first registering the parametrization to old_module
        if parametrize.is_parametrized(old_module):
            for tensor_name in old_module.parametrizations:
                for param_func in old_module.parametrizations[tensor_name]:
                    parametrize.register_parametrization(
                        new_module, tensor_name, param_func, unsafe=True)
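Why re-register instead of calling parametrize.transfer_parametrizations_and_params: attaching the same parametrization objects to the new module keeps any parameters they own (e.g. a rotation matrix shared across layers) tied. A minimal sketch, reusing RotationWeightParametrization as in the tests below; the two Linear modules are hypothetical:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

from brevitas.utils.rotation_utils import RotationWeightParametrization

# A single rotation matrix meant to be shared between modules
rot_mat = nn.Parameter(torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)))
param_func = RotationWeightParametrization(
    rot_mat=rot_mat,
    rot_func=lambda tensor, rot_mat, K: torch.matmul(tensor, rot_mat),
    axis=1,
    K=None,
)
old_module, new_module = nn.Linear(2, 3), nn.Linear(2, 3)
for module in (old_module, new_module):
    # unsafe=True skips the consistency checks, mirroring the code above
    parametrize.register_parametrization(module, "weight", param_func, unsafe=True)
# Both modules reference the very same rot_mat Parameter, so the tie survives
assert (old_module.parametrizations["weight"][0].rot_mat
        is new_module.parametrizations["weight"][0].rot_mat)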


class InsertModuleCallAfter(GraphTransform):
8 changes: 5 additions & 3 deletions src/brevitas/graph/quantize_impl.py
@@ -6,6 +6,7 @@

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

import brevitas
from brevitas.graph.base import InsertModuleCallAfter
@@ -511,7 +512,7 @@ def find_module(
    Specifically, it allows mapping nn.MultiheadAttention to its quantized counterpart and not its
    Linear submodules.
    """
    if _module_class_name(type(model)) in layer_map.keys():
    if _module_class_name(parametrize.type_before_parametrizations(model)) in layer_map.keys():
        module_to_replace.append(model)
    else:
        for name, module in model.named_children():
@@ -532,8 +533,9 @@ def layerwise_layer_handler(
    find_module(model, layer_map, module_to_replace, name_blacklist)
    rewriters = []
    for module in module_to_replace:
        if layer_map[_module_class_name(type(module))] is not None:
            quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type(module))]
        if layer_map[_module_class_name(
                parametrize.type_before_parametrizations(module))] is not None:
            quant_module_class, quant_module_kwargs = layer_map[_module_class_name(
                parametrize.type_before_parametrizations(module))]
            rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs)
            rewriters.append(rewriter)
    for rewriter in rewriters:
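The switch from type(module) to parametrize.type_before_parametrizations above is needed because registering a parametrization patches the module's class in place, so a plain type() lookup would no longer match the layer_map. A minimal sketch, with nn.Identity standing in for a real parametrization:

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

linear = nn.Linear(2, 3)
parametrize.register_parametrization(linear, "weight", nn.Identity())
# The class is replaced by a dynamically generated subclass...
print(type(linear).__name__)  # ParametrizedLinear
# ...while type_before_parametrizations still reports the original class
print(parametrize.type_before_parametrizations(linear).__name__)  # Linear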
100 changes: 100 additions & 0 deletions tests/brevitas/graph/test_quantize.py
@@ -1,7 +1,14 @@
import copy

import pytest_cases
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

from brevitas.graph.base import _remove_parametrization_entries_state_dict
from brevitas.graph.quantize import layerwise_quantize
from brevitas.utils.python_utils import recurse_getattr
from brevitas.utils.rotation_utils import RotationWeightParametrization


@pytest_cases.parametrize(
@@ -42,3 +49,96 @@ def test_layerwise_quantize_blacklist(kwargs):
            assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}"
            checked = True
    assert checked, f"Layer named {key} not found. Layer names are: {found_names}"


@pytest_cases.parametrize(
    'kwargs',
    [{
        'model': nn.Sequential(nn.Linear(2, 3)),
        'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)),
        'rot_func': lambda tensor, rot_mat, K: torch.matmul(tensor, rot_mat),
        'key': '0',
        'expected': "<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>"},])
def test_layerwise_quantize_parametrized_modules(kwargs):
    key = kwargs['key']
    exp = kwargs['expected']
    rot_mat = kwargs['rot_mat']
    rot_func = kwargs['rot_func']
    del kwargs['key']
    del kwargs['expected']
    del kwargs['rot_mat']
    del kwargs['rot_func']

    model = kwargs["model"]
    module = recurse_getattr(model, key)
    # Register rotation parametrization to module
    parametrize.register_parametrization(
        module=module,
        tensor_name="weight",
        parametrization=RotationWeightParametrization(
            rot_mat=nn.Parameter(rot_mat),
            rot_func=rot_func,
            axis=1,
            K=None,
        ))
    qmodel = layerwise_quantize(**kwargs)
    checked = False
    found_names = []
    for n, m in qmodel.named_modules():
        found_names.append(n)
        if n == key:
            mt = str(type(m))
            assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}"
            checked = True
    assert checked, f"Layer named {key} not found. Layer names are: {found_names}"


@pytest_cases.parametrize(
    'kwargs',
    [{
        'model': nn.Sequential(nn.Linear(2, 3)),
        'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)),
        'rot_func': lambda tensor, rot_mat, K: torch.matmul(tensor, rot_mat),
        'key': '0',
        'expected_state_dict_keys': ['0.weight', '0.bias'],}])
def test_remove_parametrization_entries_state_dict(kwargs):
    key = kwargs['key']
    rot_mat = kwargs['rot_mat']
    rot_func = kwargs['rot_func']
    expected_state_dict_keys = kwargs['expected_state_dict_keys']
    del kwargs['key']
    del kwargs['rot_mat']
    del kwargs['rot_func']
    del kwargs['expected_state_dict_keys']

    model = kwargs["model"]
    module = recurse_getattr(model, key)
    old_state_dict = copy.deepcopy(model.state_dict())
    # Register rotation parametrization to module
    parametrize.register_parametrization(
        module=module,
        tensor_name="weight",
        parametrization=RotationWeightParametrization(
            rot_mat=nn.Parameter(rot_mat),
            rot_func=rot_func,
            axis=1,
            K=None,
        ))
    # Retrieve state dict after parametrization
    state_dict = model.state_dict()
    # Remove parametrization entries from state dict
    state_dict = _remove_parametrization_entries_state_dict(state_dict)
    # Verify that all the expected keys in expected_state_dict_keys
    # are present in state_dict
    assert len(set(expected_state_dict_keys) - set(state_dict.keys())) == 0
    # Verify that the keys match
    for key, value in state_dict.items():
        # Verify that the key is in the expected keys
        assert key in expected_state_dict_keys, f"Unexpected key {key} in state_dict"
        # Compare tensor values
        assert torch.allclose(value, old_state_dict[key], rtol=0.0, atol=0.0), f"Value of tensor {value} does not match that in the original state_dict"
