Functionalize #748

Closed
wants to merge 14 commits
540 changes: 254 additions & 286 deletions notebooks/01_quant_tensor_quant_conv2d_overview.ipynb

Large diffs are not rendered by default.

200 changes: 66 additions & 134 deletions notebooks/02_quant_activation_overview.ipynb

Large diffs are not rendered by default.

568 changes: 202 additions & 366 deletions notebooks/03_anatomy_of_a_quantizer.ipynb

Large diffs are not rendered by default.

356 changes: 229 additions & 127 deletions notebooks/Brevitas_TVMCon2021.ipynb

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions src/brevitas/core/quant/binary.py
@@ -48,8 +48,9 @@ class BinaryQuant(brevitas.jit.ScriptModule):
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
"""

def __init__(self, scaling_impl: Module, quant_delay_steps: int = 0):
def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):
super(BinaryQuant, self).__init__()
assert signed, "Unsigned binary quant not supported"
self.scaling_impl = scaling_impl
self.bit_width = BitWidthConst(1)
self.zero_point = StatelessBuffer(torch.tensor(0.0))
@@ -58,7 +59,7 @@ def __init__(self, scaling_impl: Module, quant_delay_steps: int = 0):
@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
scale = self.scaling_impl(x)
y = binary_sign_ste(x) * scale
y = binary_sign_ste(x) #* scale
y = self.delay_wrapper(x, y)
return y, scale, self.zero_point(), self.bit_width()

@@ -119,6 +120,6 @@ def __init__(
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
scale = self.scaling_impl(x)
y = self.tensor_clamp_impl(x, -scale, scale)
y = binary_sign_ste(y) * scale
y = binary_sign_ste(y) #* scale
y = self.delay_wrapper(x, y)
return y, scale, self.zero_point(), self.bit_width()
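
With this change, `BinaryQuant.forward` (and the clamped variant in the second hunk) returns the raw ±1 sign instead of the already-dequantized value, while the scale is still returned as metadata. A minimal PyTorch sketch of the split, using hypothetical helper names rather than the Brevitas modules (the straight-through-estimator part is omitted):

```python
import torch

def binary_quant_functional(x: torch.Tensor, scale: torch.Tensor):
    # Sketch of the new behaviour: return the sign only, plus the metadata
    # needed to dequantize later (names are illustrative, not Brevitas API).
    y = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))
    zero_point = torch.tensor(0.0)
    bit_width = torch.tensor(1.0)
    return y, scale, zero_point, bit_width

def dequantize(y, scale, zero_point):
    # The step that used to be fused into forward() as `binary_sign_ste(x) * scale`.
    return (y - zero_point) * scale
```
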
13 changes: 7 additions & 6 deletions src/brevitas/core/quant/int_base.py
@@ -81,10 +81,11 @@ def max_int(self, bit_width):

@brevitas.jit.script_method
def forward(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor:
y_int = self.to_int(scale, zero_point, bit_width, x)
y = y_int - zero_point
y = y * scale
y = self.to_int(scale, zero_point, bit_width, x)
# y = y_int - zero_point
# y = y * scale
y = self.delay_wrapper(x, y)
# print(f"Only once {y}")
return y


@@ -163,8 +164,8 @@ def forward(
zero_point: Tensor,
bit_width: Tensor,
x: Tensor) -> Tensor:
y_int = self.to_int(pre_scale, pre_zero_point, bit_width, x)
y = y_int - zero_point
y = y * scale
y = self.to_int(pre_scale, pre_zero_point, bit_width, x)
# y = y_int - zero_point
# y = y * scale
y = self.delay_wrapper(x, y)
return y
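
The integer quantizers follow the same pattern: `to_int` still projects the input onto the integer grid, but the dequantization (`- zero_point`, `* scale`) is commented out, so `forward` now returns integer-domain values. A rough functional sketch of the two halves, assuming a signed range and round-to-nearest (illustrative only, not the Brevitas implementation):

```python
import torch

def quantize_to_int(x, scale, zero_point, bit_width: int):
    # What to_int conceptually does: map onto the integer grid and clamp.
    int_min = -2 ** (bit_width - 1)
    int_max = 2 ** (bit_width - 1) - 1
    return torch.clamp(torch.round(x / scale + zero_point), int_min, int_max)

def dequantize(y_int, scale, zero_point):
    # The part commented out in this diff, now expected to happen downstream.
    return (y_int - zero_point) * scale

x = torch.randn(8)
scale = torch.tensor(0.05)
y_int = quantize_to_int(x, scale, torch.tensor(0.0), 8)
x_hat = dequantize(y_int, scale, torch.tensor(0.0))  # fake-quantized reconstruction
```
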
2 changes: 1 addition & 1 deletion src/brevitas/core/stats/stats_op.py
@@ -479,7 +479,7 @@ def evaluate_loss(self, x, candidate):
self.set_local_loss_mode(True)
quant_value = self.proxy_forward(x)
if isinstance(quant_value, tuple):
quant_value = quant_value[0]
quant_value = quant_value.value
loss = self.mse_loss_fn(x, quant_value)
self.set_local_loss_mode(False)
return loss
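
For context, `evaluate_loss` scores a candidate quantization parameter by measuring the MSE between the input and its quantized reconstruction; with this PR the quantized value is read via `.value` rather than by tuple indexing. A self-contained sketch of that kind of MSE evaluation (hypothetical helper, not the actual stats implementation):

```python
import torch

def mse_candidate_loss(x: torch.Tensor, candidate_scale: torch.Tensor, bit_width: int = 8):
    # Fake-quantize x with the candidate scale and measure reconstruction error.
    int_min = -2 ** (bit_width - 1)
    int_max = 2 ** (bit_width - 1) - 1
    x_hat = torch.clamp(torch.round(x / candidate_scale), int_min, int_max) * candidate_scale
    return torch.mean((x - x_hat) ** 2)
```
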
@@ -104,7 +104,7 @@ def input_quant_symbolic_kwargs(cls, module):

@classmethod
def input_dequant_symbolic_kwargs(cls, module):
if module._cached_inp.scale is not None:
if module._cached_inp is not None:
return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp)
else:
return None
2 changes: 1 addition & 1 deletion src/brevitas/graph/calibrate.py
@@ -142,7 +142,7 @@ def disable_act_quant_hook(self, module, inp, output):
if isinstance(module.tracked_module_list[0], QuantHardTanh):
inp = F.hardtanh(
inp, min_val=module.quant_injector.min_val, max_val=module.quant_injector.max_val)
return QuantTensor(value=inp, training=module.training)
return inp

def disable_act_quantization(self, model, is_training):
# If self.call_act_quantizer_impl is set to True, the quantization will be performed but the output
39 changes: 24 additions & 15 deletions src/brevitas/nn/mixin/base.py
@@ -18,6 +18,7 @@
from brevitas.inject import ExtendedInjector
from brevitas.inject import Injector
from brevitas.nn.utils import compute_channel_view_shape
from brevitas.quant_tensor import _is_all_nested_not_none
from brevitas.quant_tensor import QuantTensor

from .utils import filter_kwargs
@@ -28,7 +29,7 @@ class _CachedIO:
def __init__(self, quant_tensor: QuantTensor, metadata_only: bool):
self.shape = quant_tensor.value.shape
if metadata_only:
self.quant_tensor = quant_tensor.set(value=None)
self.quant_tensor = quant_tensor.set(qt_value=None)
else:
self.quant_tensor = quant_tensor

@@ -166,15 +167,17 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]):
if not self.training and not self._export_mode and self.cache_inference_quant_inp:
cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
self._cached_inp = cached_inp
else:
inp = QuantTensor(inp, training=self.training)
if not self.training and self.cache_inference_quant_inp:
cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
self._cached_inp = cached_inp
# else:
# if not self.training and self.cache_inference_quant_inp:
# cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
# self._cached_inp = cached_inp
# Remove any naming metadata to avoid downstream errors
# Avoid inplace operations on the input in case of forward hooks
if not torch._C._get_tracing_state():
inp = inp.set(value=inp.value.rename(None))
if isinstance(inp, QuantTensor):
inp = inp.set(qt_value=inp.qt_value.rename(None))
else:
inp = inp.rename(None)
return inp

def pack_output(self, quant_output: QuantTensor):
@@ -184,7 +187,10 @@ def pack_output(self, quant_output: QuantTensor):
if self.return_quant_tensor:
return quant_output
else:
return quant_output.value
if isinstance(quant_output, QuantTensor):
return quant_output.value
else:
return quant_output


class QuantRecurrentLayerMixin(ExportMixin):
@@ -246,9 +252,10 @@ def gate_params_fwd(gate, quant_input):
acc_bit_width = None
quant_weight_ih = gate.input_weight()
quant_weight_hh = gate.hidden_weight()
if quant_input.bit_width is not None:
if isinstance(quant_input, QuantTensor):
acc_bit_width = None # TODO
if quant_input.scale is not None and quant_weight_ih.scale is not None:
if getattr(quant_input, 'scale', None) is not None and getattr(
quant_weight_ih, 'scale', None) is not None:
acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1)
acc_scale = quant_weight_ih.scale.view(acc_scale_shape)
acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape)
@@ -267,16 +274,16 @@ def maybe_quantize_input(self, inp):
quant_input = inp
if not self.quantize_output_only:
quant_input = self.io_quant(quant_input)
elif not isinstance(inp, QuantTensor):
quant_input = QuantTensor(quant_input)
# elif not isinstance(inp, QuantTensor):
# quant_input = QuantTensor(quant_input)
return quant_input

def maybe_quantize_state(self, inp, state, quant):
if state is None:
batch_size = inp.size(0) if self.cell.batch_first else inp.size(1)
quant_state = torch.zeros(
int(batch_size), self.hidden_size, dtype=inp.dtype, device=inp.device)
quant_state = QuantTensor(quant_state)
# quant_state = QuantTensor(quant_state)
else:
quant_state = quant(state)
return quant_state
@@ -303,7 +310,8 @@ def pack_quant_outputs(self, quant_outputs):
quant_output[2],
quant_output[3],
self.io_quant.is_signed,
self.training) for quant_output in quant_outputs]
self.training,
_allow_empty=True) for quant_output in quant_outputs]
else:
outputs = [torch.unsqueeze(o[0], dim=seq_dim) for o in quant_outputs]
if self.reverse_input:
@@ -331,7 +339,8 @@ def pack_quant_state(self, quant_state, quant):
quant_state[2],
quant_state[3],
quant.is_signed,
self.training)
training=self.training,
_allow_empty=True)
else:
quant_state = torch.unsqueeze(quant_state[0], dim=0)
return quant_state
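
Most of the edits in this mixin follow one defensive pattern: the input may now be either a plain `Tensor` or a `QuantTensor`, so quantization metadata is only read when it is actually present. A small sketch of that pattern with a stand-in wrapper class (names are illustrative, not the Brevitas classes):

```python
from typing import Optional, Union

import torch

class FakeQuantTensor:
    # Stand-in for QuantTensor, only for this illustration.
    def __init__(self, value: torch.Tensor, scale: Optional[torch.Tensor] = None):
        self.value = value
        self.scale = scale

def combined_acc_scale(inp, weight) -> Optional[torch.Tensor]:
    # Same idea as gate_params_fwd above: only combine scales when both
    # operands carry one; otherwise return no metadata at all.
    in_scale = getattr(inp, 'scale', None)
    w_scale = getattr(weight, 'scale', None)
    if in_scale is not None and w_scale is not None:
        return in_scale * w_scale
    return None

def unwrap(t: Union[torch.Tensor, FakeQuantTensor]) -> torch.Tensor:
    # Same idea as pack_output above: wrappers are unwrapped, tensors pass through.
    return t.value if isinstance(t, FakeQuantTensor) else t
```
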
6 changes: 5 additions & 1 deletion src/brevitas/nn/mixin/parameter.py
@@ -198,7 +198,11 @@ def quant_bias_zero_point(self):
if self.bias is None:
return None
if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
return self.bias_quant(self.bias).zero_point
bias_quant = self.bias_quant(self.bias)
if isinstance(bias_quant, QuantTensor):
return bias_quant.zero_point
else:
return None
else:
if self._cached_bias is None:
raise RuntimeError(
78 changes: 55 additions & 23 deletions src/brevitas/nn/quant_layer.py
@@ -19,6 +19,10 @@
from .utils import rename_state_dict_by_prefix


def return_value(tensor):
return tensor.value if isinstance(tensor, QuantTensor) else tensor


class QuantNonLinearActLayer(QuantNonLinearActMixin, QuantInputMixin, QuantLayerMixin, Module):
__metaclass__ = ABCMeta

@@ -303,61 +307,89 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:

# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(inp.value)
inp_value = getattr(inp, 'value', inp)
out = self.export_handler(inp_value)
self._set_global_is_quant_layer(False)
return out

quant_input = self.input_quant(inp)
# quant_input_value = getattr(quant_input, 'value', quant_input)
# quant_input_scale = getattr(quant_input, 'scale', None)
# quant_input_bitwidth = getattr(quant_input, 'bit_width', None)

quant_weight = self.quant_weight(quant_input)
# quant_weight_value = getattr(quant_weight, 'value', quant_weight)
# quant_weight_scale = getattr(quant_weight, 'scale', None)
# quant_weight_bitwidth = getattr(quant_weight, 'bit_width', None)
compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance(
quant_weight, QuantTensor)
if not (compute_output_quant_tensor or
self.is_output_quant_enabled) and self.return_quant_tensor:
raise RuntimeError("QuantLayer is not correctly configured")

if (self.return_quant_tensor or
(self.is_bias_quant_enabled and
(self.bias_quant.requires_input_scale or self.bias_quant.requires_input_bit_width))):
if quant_input.bit_width is not None and quant_weight.bit_width is not None:
if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
output_bit_width = self.max_acc_bit_width(
quant_input.bit_width, quant_weight.bit_width)
if quant_input.scale is not None and quant_weight.scale is not None:

output_scale = self.quant_output_scale_impl(
inp, quant_input.scale, quant_weight.scale)
if quant_input.signed is not None:
output_signed = inp.signed or quant_weight.signed

quant_input_signed = quant_input.signed if isinstance(
quant_input, QuantTensor) else True
quant_weight_signed = quant_weight.signed if isinstance(
quant_weight, QuantTensor) else True
output_signed = quant_input_signed or quant_weight_signed

if self.bias is not None:
quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width)
quant_bias_value = getattr(quant_bias, 'value', quant_bias)
quant_bias_scale = getattr(quant_bias, 'scale', None)
quant_bias_bitwidth = getattr(quant_bias, 'bit_width', None)
if not self.training and self.cache_inference_quant_bias:
self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)

output_tensor = self.inner_forward_impl(
quant_input.value, quant_weight.value, quant_bias.value)
return_value(quant_input), return_value(quant_weight), return_value(quant_bias))

if (self.return_quant_tensor and output_scale is not None and
(quant_bias.scale is None or
(quant_bias.scale is not None and
quant_bias.scale.data_ptr() != output_scale.data_ptr()))):
output_scale_broadcast_shape = compute_channel_view_shape(inp, channel_dim=1)
output_zero_point = -quant_bias.value.view(
(quant_bias_scale is None or
(quant_bias_scale is not None and
quant_bias_scale.data_ptr() != output_scale.data_ptr()))):
channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
output_scale_broadcast_shape = compute_channel_view_shape(
inp, channel_dim=channel_dim)
output_zero_point = -quant_bias_value.view(
output_scale_broadcast_shape) / output_scale

if quant_bias.bit_width is not None and output_bit_width is not None:
if hasattr(quant_bias, 'bit_width'
) and quant_bias_bitwidth is not None and output_bit_width is not None:
output_bit_width = torch.where(
quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width)
quant_bias_bitwidth > output_bit_width, quant_bias_bitwidth, output_bit_width)
output_bit_width = output_bit_width + 1
else:
output_tensor = self.inner_forward_impl(quant_input.value, quant_weight.value, None)
output_tensor = self.inner_forward_impl(
return_value(quant_input), return_value(quant_weight), None)

if self.return_quant_tensor and not self.is_output_quant_enabled:
if (quant_input.zero_point is not None and quant_weight.zero_point is not None and
if (isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor) and
((quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any())):
raise RuntimeError("Computing zero point of output accumulator not supported yet.")
elif quant_input.zero_point is not None and output_zero_point is None:
output_zero_point = quant_input.zero_point
elif self.return_quant_tensor and output_zero_point is None:
output_zero_point = torch.zeros(1).type_as(output_tensor)

quant_output = QuantTensor(
value=output_tensor,
scale=output_scale,
zero_point=output_zero_point,
bit_width=output_bit_width,
signed=output_signed,
training=self.training)
if not self.return_quant_tensor or not compute_output_quant_tensor:
quant_output = output_tensor
else:
quant_output = QuantTensor.from_fake_quantized(
output_tensor,
scale=output_scale,
zero_point=output_zero_point,
bit_width=output_bit_width,
signed=output_signed,
training=self.training)
quant_output = self.output_quant(quant_output)
return self.pack_output(quant_output)
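
The new module-level `return_value` helper is what lets `inner_forward_impl` accept either representation. A quick usage sketch with a dummy wrapper (the real `QuantTensor` constructor is not shown here):

```python
import torch

class DummyQuantTensor:
    # Minimal stand-in with a .value attribute, for illustration only.
    def __init__(self, value: torch.Tensor):
        self.value = value

def return_value(tensor):
    # Same shape as the helper added at the top of quant_layer.py.
    return tensor.value if isinstance(tensor, DummyQuantTensor) else tensor

t = torch.randn(4)
assert return_value(t) is t                               # plain tensors pass through
assert torch.equal(return_value(DummyQuantTensor(t)), t)  # wrappers are unwrapped
```
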
6 changes: 3 additions & 3 deletions src/brevitas/nn/quant_mha.py
@@ -409,7 +409,7 @@ def multi_head_attention(
# Mark dimensions through named tensors.
if not torch._C._get_tracing_state():
if isinstance(query, QuantTensor):
query.value.rename_('L', 'N', 'E')
query.qt_value.rename_('L', 'N', 'E')
else:
query.rename_('L', 'N', 'E')
# self-attention
@@ -426,7 +426,7 @@
if not torch._C._get_tracing_state():
for t in [query, key, value]:
if isinstance(t, QuantTensor):
t.value.rename_('L', 'N', 'E')
t.qt_value.rename_('L', 'N', 'E')
else:
t.rename_('L', 'N', 'E')
q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)
@@ -573,7 +573,7 @@ def multi_head_attention(
# Remove names to avoid errors in unsupported downstream ops
if not torch._C._get_tracing_state():
if isinstance(attn_output, QuantTensor):
attn_output.value.rename_(None)
attn_output.qt_value.rename_(None)
else:
attn_output.rename_(None)
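
These hunks only change which attribute carries the underlying tensor (`value` → `qt_value`); the named-tensor handling itself is unchanged: dimension names are attached so shapes can be checked, then stripped again before ops that do not support named tensors. A standalone sketch of that pattern:

```python
import torch

x = torch.randn(5, 2, 8)        # (seq_len, batch, embed)
x.rename_('L', 'N', 'E')        # attach dimension names for sanity checking
assert x.names == ('L', 'N', 'E')
x.rename_(None)                 # drop names before unsupported downstream ops
assert x.names == (None, None, None)
```
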
