Feat (delay): move delay to proxies
Giuseppe5 committed Dec 30, 2024
1 parent 39ce837 commit ecfb3d7
Showing 12 changed files with 53 additions and 122 deletions.
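This commit centralizes quantization delay: instead of each core quant module (BinaryQuant, ClampedBinaryQuant, TruncIntQuant, IntQuant, DecoupledIntQuant, TernaryQuant) owning its own DelayWrapper, the proxies now read quant_delay_steps from the injector and gate the quantized value themselves. For reference, here is a minimal sketch of the contract a delay wrapper like this is assumed to follow (illustrative only, not the actual brevitas.core.quant.delay implementation): during training, return the unquantized input for the first quant_delay_steps steps, then switch to the quantized value.

import torch
from torch import Tensor


class DelayWrapperSketch(torch.nn.Module):
    """Illustrative stand-in for brevitas' DelayWrapper (assumed semantics)."""

    def __init__(self, quant_delay_steps=None):
        super().__init__()
        # None or 0 means no delay: always return the quantized value
        self.quant_delay_steps = quant_delay_steps if quant_delay_steps is not None else 0
        self.step = 0

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        # x: unquantized input, y: quantized output in de-quantized format
        if self.training and self.step < self.quant_delay_steps:
            self.step += 1
            return x  # quantization still delayed: pass the float value through
        return y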
15 changes: 2 additions & 13 deletions src/brevitas/core/quant/binary.py
@@ -10,7 +10,6 @@
 import brevitas
 from brevitas.core.bit_width import BitWidthConst
 from brevitas.core.function_wrapper import TensorClamp
-from brevitas.core.quant.delay import DelayWrapper
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops_ste import binary_sign_ste

@@ -22,7 +21,6 @@ class BinaryQuant(brevitas.jit.ScriptModule):
     Args:
         scaling_impl (Module): Module that returns a scale factor.
-        quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0
     Returns:
         Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width.

@@ -48,19 +46,17 @@ class BinaryQuant(brevitas.jit.ScriptModule):
     Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
     """

-    def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):
+    def __init__(self, scaling_impl: Module, signed: bool = True):
         super(BinaryQuant, self).__init__()
         assert signed, "Unsigned binary quant not supported"
         self.scaling_impl = scaling_impl
         self.bit_width = BitWidthConst(1)
         self.zero_point = StatelessBuffer(torch.tensor(0.0))
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)

     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         scale = self.scaling_impl(x)
         y = binary_sign_ste(x) * scale
-        y = self.delay_wrapper(x, y)
         return y, scale, self.zero_point(), self.bit_width()

@@ -74,7 +70,6 @@ class ClampedBinaryQuant(brevitas.jit.ScriptModule):
     Args:
         scaling_impl (Module): Module that returns a scale factor.
         tensor_clamp_impl (Module): Module that performs tensor-wise clamping. Default TensorClamp()
-        quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0
     Returns:
         Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width.

@@ -104,22 +99,16 @@ class ClampedBinaryQuant(brevitas.jit.ScriptModule):
     Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
     """

-    def __init__(
-            self,
-            scaling_impl: Module,
-            tensor_clamp_impl: Module = TensorClamp(),
-            quant_delay_steps: int = 0):
+    def __init__(self, scaling_impl: Module, tensor_clamp_impl: Module = TensorClamp()):
         super(ClampedBinaryQuant, self).__init__()
         self.scaling_impl = scaling_impl
         self.bit_width = BitWidthConst(1)
         self.zero_point = StatelessBuffer(torch.tensor(0.0))
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)
         self.tensor_clamp_impl = tensor_clamp_impl

     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         scale = self.scaling_impl(x)
         y = self.tensor_clamp_impl(x, -scale, scale)
         y = binary_sign_ste(y) * scale
-        y = self.delay_wrapper(x, y)
         return y, scale, self.zero_point(), self.bit_width()
6 changes: 1 addition & 5 deletions src/brevitas/core/quant/int.py
@@ -8,7 +8,6 @@
 from torch.nn import Module

 import brevitas
-from brevitas.core.quant.delay import DelayWrapper
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops_ste import round_ste

@@ -201,12 +200,10 @@ class TruncIntQuant(brevitas.jit.ScriptModule):
     """
     """

-    def __init__(
-            self, float_to_int_impl: Module, bit_width_impl: Module, quant_delay_steps: int = 0):
+    def __init__(self, float_to_int_impl: Module, bit_width_impl: Module):
         super(TruncIntQuant, self).__init__()
         self.msb_clamp_bit_width_impl = bit_width_impl
         self.float_to_int_impl = float_to_int_impl
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)

     @brevitas.jit.script_method
     def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor,
@@ -221,7 +218,6 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor,
         y = self.float_to_int_impl(y)
         y = y - zero_point
         y = y * scale
-        y = self.delay_wrapper(x, y)
         return y, scale, zero_point, output_bit_width
33 changes: 12 additions & 21 deletions src/brevitas/core/quant/int_base.py
@@ -8,7 +8,6 @@
 import brevitas
 from brevitas.core.function_wrapper import RoundSte
 from brevitas.core.function_wrapper import TensorClamp
-from brevitas.core.quant.delay import DelayWrapper
 from brevitas.function.ops import max_int
 from brevitas.function.ops import min_int

@@ -24,7 +23,6 @@ class IntQuant(brevitas.jit.ScriptModule):
         float_to_int_impl (Module): Module that performs the conversion from floating point to
             integer representation. Default: RoundSte()
         tensor_clamp_impl (Module): Module that performs clamping. Default: TensorClamp()
-        quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0
     Returns:
         Tensor: Quantized output in de-quantized format.

@@ -48,19 +46,17 @@ class IntQuant(brevitas.jit.ScriptModule):
     __constants__ = ['signed', 'narrow_range']

     def __init__(
-            self,
-            narrow_range: bool,
-            signed: bool,
-            input_view_impl: Module,
-            float_to_int_impl: Module = RoundSte(),
-            tensor_clamp_impl: Module = TensorClamp(),
-            quant_delay_steps: int = 0):
+            self,
+            narrow_range: bool,
+            signed: bool,
+            input_view_impl: Module,
+            float_to_int_impl: Module = RoundSte(),
+            tensor_clamp_impl: Module = TensorClamp()):
         super(IntQuant, self).__init__()
         self.float_to_int_impl = float_to_int_impl
         self.tensor_clamp_impl = tensor_clamp_impl
         self.signed = signed
         self.narrow_range = narrow_range
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)
         self.input_view_impl = input_view_impl

     @brevitas.jit.script_method
@@ -87,7 +83,6 @@ def forward(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tenso
         y_int = self.to_int(scale, zero_point, bit_width, x)
         y = y_int - zero_point
         y = y * scale
-        y = self.delay_wrapper(x, y)
         return y

@@ -102,7 +97,6 @@ class DecoupledIntQuant(brevitas.jit.ScriptModule):
         float_to_int_impl (Module): Module that performs the conversion from floating point to
             integer representation. Default: RoundSte()
         tensor_clamp_impl (Module): Module that performs clamping. Default: TensorClamp()
-        quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0
     Returns:
         Tensor: Quantized output in de-quantized format.

@@ -124,19 +118,17 @@ class DecoupledIntQuant(brevitas.jit.ScriptModule):
     __constants__ = ['signed', 'narrow_range']

     def __init__(
-            self,
-            narrow_range: bool,
-            signed: bool,
-            input_view_impl: Module,
-            float_to_int_impl: Module = RoundSte(),
-            tensor_clamp_impl: Module = TensorClamp(),
-            quant_delay_steps: int = 0):
+            self,
+            narrow_range: bool,
+            signed: bool,
+            input_view_impl: Module,
+            float_to_int_impl: Module = RoundSte(),
+            tensor_clamp_impl: Module = TensorClamp()):
         super(DecoupledIntQuant, self).__init__()
         self.float_to_int_impl = float_to_int_impl
         self.tensor_clamp_impl = tensor_clamp_impl
         self.signed = signed
         self.narrow_range = narrow_range
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)
         self.input_view_impl = input_view_impl

     @brevitas.jit.script_method
@@ -172,5 +164,4 @@ def forward(
         y_int = self.to_int(pre_scale, pre_zero_point, bit_width, x)
         y = y_int - zero_point
         y = y * scale
-        y = self.delay_wrapper(x, y)
         return y
3 changes: 0 additions & 3 deletions src/brevitas/core/quant/ternary.py
@@ -9,7 +9,6 @@

 import brevitas
 from brevitas.core.bit_width import BitWidthConst
-from brevitas.core.quant.delay import DelayWrapper
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops_ste import ternary_sign_ste

@@ -57,13 +56,11 @@ def __init__(self, scaling_impl: Module, threshold: float, quant_delay_steps: in
         self.threshold = threshold
         self.bit_width = BitWidthConst(2)
         self.zero_point = StatelessBuffer(torch.tensor(0.0))
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)

     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         scale = self.scaling_impl(x)
         mask = x.abs().gt(self.threshold * scale)
         y = mask.float() * ternary_sign_ste(x)
         y = y * scale
-        y = self.delay_wrapper(x, y)
         return y, scale, self.zero_point(), self.bit_width()
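With the delay gone, the core modules are pure quantizers again. A quick sketch of calling the new BinaryQuant (the ConstScaling choice and values are illustrative, following the pattern of the class docstring examples):

import torch
from brevitas.core.quant.binary import BinaryQuant
from brevitas.core.scaling import ConstScaling

quant = BinaryQuant(scaling_impl=ConstScaling(0.1))
y, scale, zero_point, bit_width = quant(torch.randn(4))
assert torch.allclose(y.abs(), scale)  # every output is +/- scale
assert int(bit_width) == 1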
12 changes: 8 additions & 4 deletions src/brevitas/proxy/parameter_quant.py
@@ -16,9 +16,9 @@
 from brevitas import config
 from brevitas import is_dynamo_compiling
 from brevitas.core.function_wrapper.misc import Identity
+from brevitas.core.quant.delay import DelayWrapper
 from brevitas.function import max_int
 from brevitas.inject import BaseInjector as Injector
-from brevitas.quant_tensor import _unpack_quant_tensor
 from brevitas.quant_tensor import IntQuantTensor
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIO

@@ -96,6 +96,8 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         self.cache_class = None  # To be redefined by each class
         self.quant_tensor_class = None  # To be redefined by each class
         self.skip_create_quant_tensor = False
+        quant_delay_steps = quant_injector.quant_delay_steps if 'quant_delay_steps' in quant_injector else None
+        self.delay_wrapper = DelayWrapper(quant_delay_steps)

     @property
     def input_view_impl(self):

@@ -138,11 +140,13 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
             else:
                 out = self.create_quant_tensor(out)
         else:
-            out = self.tensor_quant(x)
+            quant_value, *quant_args = self.tensor_quant(x)
+            quant_args = tuple(quant_args)
+            quant_value = self.delay_wrapper(x, quant_value)
             if self.skip_create_quant_tensor:
-                out = out[0]
+                out = quant_value
             else:
-                out = self.create_quant_tensor(out)
+                out = self.create_quant_tensor((quant_value,) + quant_args)
         if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
             self._cached_weight = self.cache_class(
                 out.detach(),
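Since the proxy now looks up quant_delay_steps directly on the injector ('quant_delay_steps' in quant_injector), a quantizer can still request a delay exactly as before. A hypothetical usage sketch (the quantizer subclass and step count are made up for illustration):

from brevitas.nn import QuantLinear
from brevitas.quant import Int8WeightPerTensorFloat


class DelayedInt8Weight(Int8WeightPerTensorFloat):
    quant_delay_steps = 100  # illustrative value


# The weight proxy builds its DelayWrapper from the injector attribute above.
layer = QuantLinear(16, 8, bias=True, weight_quant=DelayedInt8Weight)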
32 changes: 20 additions & 12 deletions src/brevitas/proxy/runtime_quant.py
@@ -14,6 +14,7 @@

 import brevitas
 from brevitas import is_dynamo_compiling
+from brevitas.core.quant.delay import DelayWrapper
 from brevitas.quant_tensor import IntQuantTensor
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIO

@@ -99,6 +100,8 @@ def __init__(self, quant_layer, quant_injector):
         self.cache_quant_io_metadata_only = True
         self.cache_class = None
         self.skip_create_quant_tensor = False
+        quant_delay_steps = quant_injector.quant_delay_steps if 'quant_delay_steps' in quant_injector else None
+        self.delay_wrapper = DelayWrapper(quant_delay_steps)

     @property
     def input_view_impl(self):

@@ -176,31 +179,33 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
             y = y.value

         if self.export_mode:
-            y = self.fused_activation_quant_proxy.activation_impl(y)
-            y = self.export_handler(y)
+            out = self.fused_activation_quant_proxy.activation_impl(y)
+            out = self.export_handler(out)
         elif not self.is_quant_enabled:
             # A tuple helps later with control flows
             # The second None value is used later
             # If quant is not enabled, we still apply input_view in the case of groupwise + padding
-            y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y))
-            y = (y, None)
+            out = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y))
+            out = (out, None)
         else:
-            y = self.fused_activation_quant_proxy(y)
+            out = self.fused_activation_quant_proxy(y)
         # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
         # otherwise return a simple Tensor

+        quant_value, *quant_args = out
+        quant_args = tuple(quant_args)
+        quant_value = self.delay_wrapper(y, quant_value)
         if self.skip_create_quant_tensor:
-            out = y[0]
+            out = quant_value
         else:
             # If the second value (i.e., scale) is None, then quant is disabled
-            if y[1] is not None:
-                out = self.create_quant_tensor(y)
+            if out[1] is not None:
+                out = self.create_quant_tensor((quant_value,) + quant_args)
             elif self.is_passthrough_act and isinstance(x, QuantTensor):
                 # preserve scale/zp/bit/sign even without output quant
-                y = y[0]
-                out = self.create_quant_tensor(y, x=x)
+                out = quant_value
+                out = self.create_quant_tensor(out, x=x)
             else:
-                out = y[0]
+                out = quant_value

         if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor):
             cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only)

@@ -267,6 +272,8 @@ class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol)
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.skip_create_quant_tensor = False
+        quant_delay_steps = self.quant_injector.quant_delay_steps if 'quant_delay_steps' in self.quant_injector else None
+        self.delay_wrapper = DelayWrapper(quant_delay_steps)

     def bit_width(self):
         if not self.is_quant_enabled:

@@ -285,6 +292,7 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
         else:
             out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width)
             out_value, out_scale, out_zp, out_bit_width = out_tuple
+            out_value = self.delay_wrapper(x, out_value)
             if self.skip_create_quant_tensor:
                 return out_value
             return IntQuantTensor(
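Both proxies now share the same unpack/delay/repack pattern: split the tuple returned by tensor_quant, gate only the value through the DelayWrapper, and rebuild the tuple for create_quant_tensor. A standalone illustration (the tensors and the DelayWrapperSketch from above are stand-ins):

import torch

x = torch.randn(4)  # unquantized input
# Layout of the tuple returned by tensor_quant: (value, scale, zero_point, bit_width)
out = (x.round(), torch.tensor(1.0), torch.tensor(0.0), torch.tensor(8.0))

quant_value, *quant_args = out     # value vs. quantization metadata
quant_args = tuple(quant_args)     # star-unpacking yields a list, so re-tuple
delay = DelayWrapperSketch(quant_delay_steps=0)
quant_value = delay(x, quant_value)  # the delay gates only the value
out = (quant_value,) + quant_args    # repacked in the original layout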
20 changes: 0 additions & 20 deletions tests/brevitas/core/binary_quant_fixture.py
@@ -10,11 +10,8 @@
 __all__ = [
     'binary_quant',
     'clamped_binary_quant',
-    'delayed_binary_quant',
-    'delayed_clamped_binary_quant',
     'binary_quant_impl_all',
     'binary_quant_all',  # noqa
-    'delayed_binary_quant_all',  # noqa
 ]

@@ -43,21 +40,4 @@ def clamped_binary_quant(scaling_impl_all):
     return ClampedBinaryQuant(scaling_impl=scaling_impl_all)


-@pytest_cases.fixture()
-def delayed_binary_quant(scaling_impl_all, quant_delay_steps):
-    """
-    Delayed BinaryQuant with all variants of scaling
-    """
-    return BinaryQuant(scaling_impl=scaling_impl_all, quant_delay_steps=quant_delay_steps)
-
-
-@pytest_cases.fixture()
-def delayed_clamped_binary_quant(scaling_impl_all, quant_delay_steps):
-    """
-    ClampedBinaryQuant with all variants of scaling
-    """
-    return ClampedBinaryQuant(scaling_impl=scaling_impl_all, quant_delay_steps=quant_delay_steps)
-
-
 fixture_union('binary_quant_all', ['binary_quant', 'clamped_binary_quant'])
-fixture_union('delayed_binary_quant_all', ['delayed_binary_quant', 'delayed_clamped_binary_quant'])
10 changes: 0 additions & 10 deletions tests/brevitas/core/shared_quant_fixture.py
@@ -9,7 +9,6 @@
 from brevitas.core.scaling import ParameterScaling

 __all__ = [
-    'quant_delay_steps',
     'const_scaling_impl',
     'parameter_scaling_impl',
     'standalone_scaling_init',
@@ -18,15 +17,6 @@
 ]


-@pytest_cases.fixture()
-@pytest_cases.parametrize('steps', [1, 10])
-def quant_delay_steps(steps):
-    """
-    Non-zero steps to delay quantization
-    """
-    return steps
-
-
 @pytest_cases.fixture()
 def const_scaling_impl(standalone_scaling_init):
     """
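The deleted fixtures have no replacement in this commit; if delay coverage is reintroduced at the proxy level, a test could look roughly like this (hypothetical sketch, assuming the delay semantics outlined above and that quant_delay_steps is forwarded to the activation quantizer):

import torch
from brevitas.nn import QuantIdentity


def test_act_quant_delay_sketch():
    layer = QuantIdentity(quant_delay_steps=2, return_quant_tensor=False)
    layer.train()
    x = torch.randn(8)
    assert torch.equal(layer(x), x)      # step 1: quantization delayed
    assert torch.equal(layer(x), x)      # step 2: still delayed
    assert not torch.equal(layer(x), x)  # step 3: quantization active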