Commit 29b9e35
Fix (llm): add checks for group dimensions
Giuseppe5 committed Nov 7, 2023
1 parent 513ab4d commit 29b9e35
Showing 2 changed files with 7 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/brevitas_examples/llm/llm_quant/export.py
@@ -118,6 +118,7 @@ def pack_int_weights(self, bit_width, int_weights):
         if bit_width == 8:
             return int_weights
         elif bit_width == 4 or bit_width == 2:
+            assert int_weights.shape[1] * bit_width % 8 == 0, "Number of columns multiplied by the bit-width must be a multiple of 8"
             packed_int_weights = torch.zeros(
                 (int_weights.shape[0], int_weights.shape[1] * bit_width // 8),
                 device=int_weights.device,
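
The added assert guards the packing step that follows it: at 4 or 2 bits, several integer values are packed into each uint8 byte, so the column count times the bit-width must come out to a whole number of bytes. A minimal standalone sketch of the 4-bit case, assuming a 2D integer weight tensor (pack_4bit and its shapes are illustrative, not the Brevitas API):

import torch

def pack_4bit(int_weights: torch.Tensor) -> torch.Tensor:
    # Two 4-bit values per uint8 byte, so columns * 4 must be a multiple
    # of 8, i.e. the column count must be even: the condition the assert
    # in pack_int_weights enforces for bit_width == 4.
    assert int_weights.shape[1] * 4 % 8 == 0
    low = int_weights[:, 0::2] & 0xF           # even columns -> low nibble
    high = (int_weights[:, 1::2] & 0xF) << 4   # odd columns -> high nibble
    return (low | high).to(torch.uint8)

w = torch.randint(0, 16, (2, 6))   # 6 columns * 4 bits = 24 bits per row
print(pack_4bit(w).shape)          # torch.Size([2, 3]), i.e. cols * 4 // 8

With an odd column count the assert fires instead of silently producing a truncated packed tensor, which is the failure mode the check is meant to surface.
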
12 changes: 6 additions & 6 deletions src/brevitas_examples/llm/llm_quant/quantizers.py
@@ -32,22 +32,22 @@ class WeightSymmetricGroupQuantMixin(ExtendedInjector):
     @value
     def expanded_scaling_shape(module, block_size):
         if isinstance(module, nn.Conv2d):
-            return module.weight.size(0), module.weight.size(1) // block_size, block_size, module.weight.size(2), module.weight.size(3)
+            return module.weight.size(0), (module.weight.size(1) + block_size - 1) // block_size, block_size, module.weight.size(2), module.weight.size(3)
         elif isinstance(module, nn.Linear):
-            return module.weight.size(0), module.weight.size(1) // block_size, block_size
+            return module.weight.size(0), (module.weight.size(1) + block_size - 1) // block_size, block_size
         elif isinstance(module, nn.Embedding):
-            return module.weight.size(0), module.weight.size(1) // block_size, block_size
+            return module.weight.size(0), (module.weight.size(1) + block_size - 1) // block_size, block_size
         else:
             raise RuntimeError("Module not supported.")

     @value
     def scaling_shape(module, block_size):
         if isinstance(module, nn.Conv2d):
-            return module.weight.size(0), module.weight.size(1) // block_size, 1, module.weight.size(2), module.weight.size(3)
+            return module.weight.size(0), (module.weight.size(1) + block_size - 1) // block_size, 1, module.weight.size(2), module.weight.size(3)
         elif isinstance(module, nn.Linear):
-            return module.weight.size(0), module.weight.size(1) // block_size, 1
+            return module.weight.size(0), (module.weight.size(1) + block_size - 1) // block_size, 1
         elif isinstance(module, nn.Embedding):
-            return module.weight.size(0), module.weight.size(1) // block_size, 1
+            return module.weight.size(0), (module.weight.size(1) + block_size - 1) // block_size, 1
         else:
             raise RuntimeError("Module not supported.")

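The quantizers.py change replaces plain floor division with (n + block_size - 1) // block_size, the usual integer ceiling division, so the group dimension of the scaling shape is rounded up rather than truncated when the channel count is not an exact multiple of block_size. A quick check of the arithmetic with hypothetical sizes:

block_size = 64

for in_channels in (128, 100):
    floor_groups = in_channels // block_size
    ceil_groups = (in_channels + block_size - 1) // block_size
    print(in_channels, floor_groups, ceil_groups)

# 128 -> 2 groups under either rule.
# 100 -> floor division gives 1 group and would silently drop 36 channels
#        from the scaling shape; ceiling division gives 2 groups, covering
#        every channel.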

