Feat (quant_tensor): refactor to store int values
Giuseppe5 committed Nov 28, 2023
1 parent 63370f4 commit fd25826
Showing 19 changed files with 1,093 additions and 1,012 deletions.
540 changes: 254 additions & 286 deletions notebooks/01_quant_tensor_quant_conv2d_overview.ipynb

Large diffs are not rendered by default.

200 changes: 66 additions & 134 deletions notebooks/02_quant_activation_overview.ipynb

Large diffs are not rendered by default.

494 changes: 244 additions & 250 deletions notebooks/03_anatomy_of_a_quantizer.ipynb

Large diffs are not rendered by default.

426 changes: 243 additions & 183 deletions notebooks/Brevitas_TVMCon2021.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/brevitas/core/quant/binary.py
@@ -58,7 +58,7 @@ def __init__(self, scaling_impl: Module, quant_delay_steps: int = 0):
@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
scale = self.scaling_impl(x)
y = binary_sign_ste(x) * scale
y = binary_sign_ste(x) #* scale
y = self.delay_wrapper(x, y)
return y, scale, self.zero_point(), self.bit_width()

@@ -119,6 +119,6 @@ def __init__(
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
scale = self.scaling_impl(x)
y = self.tensor_clamp_impl(x, -scale, scale)
y = binary_sign_ste(y) * scale
y = binary_sign_ste(y) #* scale
y = self.delay_wrapper(x, y)
return y, scale, self.zero_point(), self.bit_width()
7 changes: 4 additions & 3 deletions src/brevitas/core/quant/int_base.py
@@ -81,10 +81,11 @@ def max_int(self, bit_width):

@brevitas.jit.script_method
def forward(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor:
y_int = self.to_int(scale, zero_point, bit_width, x)
y = y_int - zero_point
y = y * scale
y = self.to_int(scale, zero_point, bit_width, x)
# y = y_int - zero_point
# y = y * scale
y = self.delay_wrapper(x, y)
# print(f"Only once {y}")
return y


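Note: with the dequantization step commented out, the forward above now returns the integer codes produced by to_int directly. The two removed lines implemented the standard affine dequantization, which consumers of the codes are now conceptually responsible for. A minimal sketch of that mapping (illustrative names, not Brevitas API):

    import torch

    def affine_dequantize(y_int: torch.Tensor, scale: torch.Tensor,
                          zero_point: torch.Tensor) -> torch.Tensor:
        # Exactly what the commented-out lines used to do: shift by the
        # zero point, then rescale back to the float domain.
        return (y_int - zero_point) * scale
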
2 changes: 1 addition & 1 deletion src/brevitas/core/stats/stats_op.py
@@ -477,7 +477,7 @@ def evaluate_loss(self, x, candidate):
self.set_local_loss_mode(True)
quant_value = self.proxy_forward(x)
if isinstance(quant_value, tuple):
quant_value = quant_value[0]
quant_value = quant_value.value
loss = self.mse_loss_fn(x, quant_value)
self.set_local_loss_mode(False)
return loss
2 changes: 1 addition & 1 deletion src/brevitas/graph/calibrate.py
@@ -142,7 +142,7 @@ def disable_act_quant_hook(self, module, inp, output):
if isinstance(module.tracked_module_list[0], QuantHardTanh):
inp = F.hardtanh(
inp, min_val=module.quant_injector.min_val, max_val=module.quant_injector.max_val)
return QuantTensor(value=inp, training=module.training)
return inp

def disable_act_quantization(self, model, is_training):
# If self.call_act_quantizer_impl is set to True, the quantization will be performed but the output
30 changes: 20 additions & 10 deletions src/brevitas/nn/mixin/base.py
Expand Up @@ -167,14 +167,18 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]):
cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
self._cached_inp = cached_inp
else:
inp = QuantTensor(inp, training=self.training)
# inp = QuantTensor(inp, scale=torch.tensor(1.0, device=inp.device, dtype=inp.dtype), training=self.training)
if not self.training and self.cache_inference_quant_inp:
cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
self._cached_inp = cached_inp
# print(inp)
# Remove any naming metadata to avoid downstream errors
# Avoid inplace operations on the input in case of forward hooks
if not torch._C._get_tracing_state():
inp = inp.set(value=inp.value.rename(None))
if isinstance(inp, QuantTensor):
inp = inp.set(qt_value=inp.qt_value.rename(None))
else:
inp = inp.rename(None)
return inp

def pack_output(self, quant_output: QuantTensor):
@@ -184,7 +188,10 @@ def pack_output(self, quant_output: QuantTensor):
if self.return_quant_tensor:
return quant_output
else:
return quant_output.value
if isinstance(quant_output, QuantTensor):
return quant_output.value
else:
return quant_output


class QuantRecurrentLayerMixin(ExportMixin):
@@ -246,9 +253,10 @@ def gate_params_fwd(gate, quant_input):
acc_bit_width = None
quant_weight_ih = gate.input_weight()
quant_weight_hh = gate.hidden_weight()
if quant_input.bit_width is not None:
if isinstance(quant_input, QuantTensor):
acc_bit_width = None # TODO
if quant_input.scale is not None and quant_weight_ih.scale is not None:
if getattr(quant_input, 'scale', None) is not None and getattr(
quant_weight_ih, 'scale', None) is not None:
acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1)
acc_scale = quant_weight_ih.scale.view(acc_scale_shape)
acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape)
@@ -267,16 +275,16 @@ def maybe_quantize_input(self, inp):
quant_input = inp
if not self.quantize_output_only:
quant_input = self.io_quant(quant_input)
elif not isinstance(inp, QuantTensor):
quant_input = QuantTensor(quant_input)
# elif not isinstance(inp, QuantTensor):
# quant_input = QuantTensor(quant_input)
return quant_input

def maybe_quantize_state(self, inp, state, quant):
if state is None:
batch_size = inp.size(0) if self.cell.batch_first else inp.size(1)
quant_state = torch.zeros(
int(batch_size), self.hidden_size, dtype=inp.dtype, device=inp.device)
quant_state = QuantTensor(quant_state)
# quant_state = QuantTensor(quant_state)
else:
quant_state = quant(state)
return quant_state
@@ -303,7 +311,8 @@ def pack_quant_outputs(self, quant_outputs):
quant_output[2],
quant_output[3],
self.io_quant.is_signed,
self.training) for quant_output in quant_outputs]
self.training,
_allow_empty=True) for quant_output in quant_outputs]
else:
outputs = [torch.unsqueeze(o[0], dim=seq_dim) for o in quant_outputs]
if self.reverse_input:
@@ -331,7 +340,8 @@ def pack_quant_state(self, quant_state, quant):
quant_state[2],
quant_state[3],
quant.is_signed,
self.training)
training=self.training,
_allow_empty=True)
else:
quant_state = torch.unsqueeze(quant_state[0], dim=0)
return quant_state
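Note: a pattern that recurs throughout this commit is accepting either a QuantTensor or a plain torch.Tensor and unwrapping only in the former case, as in unpack_input and pack_output above. A minimal sketch of that dispatch idiom (illustrative helper, not Brevitas API; the return_value helper added in quant_layer.py below plays the same role):

    from brevitas.quant_tensor import QuantTensor

    def maybe_value(t):
        # Unwrap to the dequantized tensor when given a QuantTensor,
        # otherwise pass the plain tensor through unchanged.
        return t.value if isinstance(t, QuantTensor) else t
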
74 changes: 52 additions & 22 deletions src/brevitas/nn/quant_layer.py
@@ -19,6 +19,10 @@
from .utils import rename_state_dict_by_prefix


def return_value(tensor):
return tensor.value if isinstance(tensor, QuantTensor) else tensor


class QuantNonLinearActLayer(QuantNonLinearActMixin, QuantInputMixin, QuantLayerMixin, Module):
__metaclass__ = ABCMeta

@@ -308,56 +312,82 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
return out

quant_input = self.input_quant(inp)
# quant_input_value = getattr(quant_input, 'value', quant_input)
# quant_input_scale = getattr(quant_input, 'scale', None)
# quant_input_bitwidth = getattr(quant_input, 'bit_width', None)

quant_weight = self.quant_weight(quant_input)
# quant_weight_value = getattr(quant_weight, 'value', quant_weight)
# quant_weight_scale = getattr(quant_weight, 'scale', None)
# quant_weight_bitwidth = getattr(quant_weight, 'bit_width', None)

if ((not isinstance(quant_input, QuantTensor) or not isinstance(quant_weight, QuantTensor))
and not self.is_output_quant_enabled) and self.return_quant_tensor:
raise RuntimeError("QuantLayer is not correctly configured")

if (self.return_quant_tensor or
(self.is_bias_quant_enabled and
(self.bias_quant.requires_input_scale or self.bias_quant.requires_input_bit_width))):
if quant_input.bit_width is not None and quant_weight.bit_width is not None:
if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
output_bit_width = self.max_acc_bit_width(
quant_input.bit_width, quant_weight.bit_width)
if quant_input.scale is not None and quant_weight.scale is not None:

output_scale = self.quant_output_scale_impl(
inp, quant_input.scale, quant_weight.scale)
if quant_input.signed is not None:
output_signed = inp.signed or quant_weight.signed

quant_input_signed = quant_input.signed if isinstance(
quant_input, QuantTensor) else True
quant_weight_signed = quant_weight.signed if isinstance(
quant_weight, QuantTensor) else True
output_signed = quant_input_signed or quant_weight_signed

if self.bias is not None:
quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width)
quant_bias_value = getattr(quant_bias, 'value', quant_bias)
quant_bias_scale = getattr(quant_bias, 'scale', None)
quant_bias_bitwidth = getattr(quant_bias, 'bit_width', None)
if not self.training and self.cache_inference_quant_bias:
self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)

output_tensor = self.inner_forward_impl(
quant_input.value, quant_weight.value, quant_bias.value)
return_value(quant_input), return_value(quant_weight), return_value(quant_bias))

if (self.return_quant_tensor and output_scale is not None and
(quant_bias.scale is None or
(quant_bias.scale is not None and
quant_bias.scale.data_ptr() != output_scale.data_ptr()))):
output_scale_broadcast_shape = compute_channel_view_shape(inp, channel_dim=1)
output_zero_point = -quant_bias.value.view(
(quant_bias_scale is None or
(quant_bias_scale is not None and
quant_bias_scale.data_ptr() != output_scale.data_ptr()))):
channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
output_scale_broadcast_shape = compute_channel_view_shape(
inp, channel_dim=channel_dim)
output_zero_point = -quant_bias_value.view(
output_scale_broadcast_shape) / output_scale

if quant_bias.bit_width is not None and output_bit_width is not None:
if hasattr(quant_bias, 'bit_width'
) and quant_bias_bitwidth is not None and output_bit_width is not None:
output_bit_width = torch.where(
quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width)
quant_bias_bitwidth > output_bit_width, quant_bias_bitwidth, output_bit_width)
output_bit_width = output_bit_width + 1
else:
output_tensor = self.inner_forward_impl(quant_input.value, quant_weight.value, None)
output_tensor = self.inner_forward_impl(
return_value(quant_input), return_value(quant_weight), None)

if self.return_quant_tensor and not self.is_output_quant_enabled:
if (quant_input.zero_point is not None and quant_weight.zero_point is not None and
if (isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor) and
((quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any())):
raise RuntimeError("Computing zero point of output accumulator not supported yet.")
elif quant_input.zero_point is not None and output_zero_point is None:
output_zero_point = quant_input.zero_point
elif self.return_quant_tensor and output_zero_point is None:
output_zero_point = torch.zeros(1).type_as(output_tensor)

quant_output = QuantTensor(
value=output_tensor,
scale=output_scale,
zero_point=output_zero_point,
bit_width=output_bit_width,
signed=output_signed,
training=self.training)
if not self.return_quant_tensor or (output_scale is None and output_zero_point is None):
quant_output = output_tensor
else:
quant_output = QuantTensor.from_fake_quantized(
output_tensor,
scale=output_scale,
zero_point=output_zero_point,
bit_width=output_bit_width,
signed=output_signed,
training=self.training)
quant_output = self.output_quant(quant_output)
return self.pack_output(quant_output)
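Note: the rewritten output path above only builds a QuantTensor when scale or zero-point metadata is available, and does so through QuantTensor.from_fake_quantized, whose implementation is not part of this diff. Purely as an illustration of what storing int values implies, a hypothetical sketch of such a conversion (the function name and signature are assumptions, not the actual Brevitas constructor):

    import torch

    def fake_quantized_to_int(value: torch.Tensor, scale: torch.Tensor,
                              zero_point: torch.Tensor) -> torch.Tensor:
        # Invert the affine transform to recover the integer codes from an
        # already fake-quantized (dequantized) tensor.
        return torch.round(value / scale + zero_point)
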
6 changes: 3 additions & 3 deletions src/brevitas/nn/quant_mha.py
@@ -409,7 +409,7 @@ def multi_head_attention(
# Mark dimensions through named tensors.
if not torch._C._get_tracing_state():
if isinstance(query, QuantTensor):
query.value.rename_('L', 'N', 'E')
query.qt_value.rename_('L', 'N', 'E')
else:
query.rename_('L', 'N', 'E')
# self-attention
@@ -426,7 +426,7 @@ def multi_head_attention(
if not torch._C._get_tracing_state():
for t in [query, key, value]:
if isinstance(t, QuantTensor):
t.value.rename_('L', 'N', 'E')
t.qt_value.rename_('L', 'N', 'E')
else:
t.rename_('L', 'N', 'E')
q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)
@@ -573,7 +573,7 @@ def multi_head_attention(
# Remove names to avoid errors in unsupported downstream ops
if not torch._C._get_tracing_state():
if isinstance(attn_output, QuantTensor):
attn_output.value.rename_(None)
attn_output.qt_value.rename_(None)
else:
attn_output.rename_(None)

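Note: the value to qt_value renames above all apply the same named-tensor bookkeeping to whichever tensor is actually stored. A small illustrative helper capturing that pattern (not part of the diff):

    from brevitas.quant_tensor import QuantTensor

    def rename_dims(t, *names):
        # t may be a QuantTensor (integer codes stored in qt_value) or a
        # plain tensor; apply the in-place rename to the stored tensor.
        target = t.qt_value if isinstance(t, QuantTensor) else t
        target.rename_(*names)

For example, rename_dims(query, 'L', 'N', 'E') marks the dimensions before the projections and rename_dims(attn_output, None) clears them afterwards.
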
65 changes: 34 additions & 31 deletions src/brevitas/nn/quant_rnn.py
@@ -23,6 +23,8 @@
from brevitas.quant import Int8WeightPerTensorFloat
from brevitas.quant import Int32Bias
from brevitas.quant import Uint8ActPerTensorFloat
from brevitas.quant_tensor import _get_dequantize_tensor
from brevitas.quant_tensor import QuantTensor

QuantTupleShortEnabled = List[Tuple[Tensor, Tensor, Tensor, Tensor]]
QuantTupleShortDisabled = List[Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]]
@@ -416,10 +418,10 @@ def forward(self, inp, state):
quant_input = self.maybe_quantize_input(inp)
quant_weight_ih, quant_weight_hh, quant_bias = self.gate_params_fwd(
self.gate_params, quant_input)
if quant_bias.value is None:
if getattr(quant_bias, 'value', quant_bias) is None:
quant_bias = torch.tensor(0., device=quant_input.value.device)
else:
quant_bias = quant_bias.value
quant_bias = _get_dequantize_tensor(quant_bias)
quant_state = self.maybe_quantize_state(quant_input.value, state, self.cell.output_quant)
if self.export_mode:
cell = self.export_handler
@@ -428,10 +430,10 @@
else:
cell = self.cell
quant_outputs = cell(
quant_input.value,
quant_state.value,
quant_weight_ih.value,
quant_weight_hh.value,
_get_dequantize_tensor(quant_input),
_get_dequantize_tensor(quant_state),
_get_dequantize_tensor(quant_weight_ih),
_get_dequantize_tensor(quant_weight_hh),
quant_bias)
quant_output = self.pack_quant_outputs(quant_outputs)
quant_state = self.pack_quant_state(quant_outputs[-1], self.cell.output_quant)
@@ -666,6 +668,7 @@ def fast_cell(self):

def forward(self, inp, hidden_state, cell_state):
quant_input = self.maybe_quantize_input(inp)
quant_input_value = _get_dequantize_tensor(quant_input)
quant_weight_ii, quant_weight_hi, quant_bias_input = self.gate_params_fwd(
self.input_gate_params, quant_input)
quant_weight_ic, quant_weight_hc, quant_bias_cell = self.gate_params_fwd(
@@ -680,26 +683,26 @@
quant_weight_if, quant_weight_hf, quant_bias_forget = self.gate_params_fwd(
self.forget_gate_params, quant_input)
# Handle None bias by setting it to 0.
if quant_bias_input.value is None:
quant_bias_input = torch.tensor(0., device=quant_input.value.device)
if getattr(quant_bias_input, 'value', quant_bias_input) is None:
quant_bias_input = torch.tensor(0., device=quant_input_value.device)
else:
quant_bias_input = quant_bias_input.value
if quant_bias_forget.value is None:
quant_bias_forget = torch.tensor(0., device=quant_input.value.device)
quant_bias_input = _get_dequantize_tensor(quant_bias_input)
if getattr(quant_bias_forget, 'value', quant_bias_forget) is None:
quant_bias_forget = torch.tensor(0., device=quant_input_value.device)
else:
quant_bias_forget = quant_bias_forget.value
if quant_bias_cell.value is None:
quant_bias_cell = torch.tensor(0., device=quant_input.value.device)
quant_bias_forget = _get_dequantize_tensor(quant_bias_forget)
if getattr(quant_bias_cell, 'value', quant_bias_cell) is None:
quant_bias_cell = torch.tensor(0., device=quant_input_value.device)
else:
quant_bias_cell = quant_bias_cell.value
if quant_bias_output.value is None:
quant_bias_output = torch.tensor(0., device=quant_input.value.device)
quant_bias_cell = _get_dequantize_tensor(quant_bias_cell)
if getattr(quant_bias_output, 'value', quant_bias_output) is None:
quant_bias_output = torch.tensor(0., device=quant_input_value.device)
else:
quant_bias_output = quant_bias_output.value
quant_bias_output = _get_dequantize_tensor(quant_bias_output)
quant_hidden_state = self.maybe_quantize_state(
quant_input.value, hidden_state, self.cell.output_quant)
quant_input_value, hidden_state, self.cell.output_quant)
quant_cell_state = self.maybe_quantize_state(
quant_input.value, cell_state, self.cell.cell_state_quant)
quant_input_value, cell_state, self.cell.cell_state_quant)
# Pick cell impl
if self.export_mode:
cell = self.export_handler
@@ -708,17 +711,17 @@
else:
cell = self.cell
quant_outputs, quant_hidden_state, quant_cell_state = cell(
quant_input.value,
quant_hidden_state.value,
quant_cell_state.value,
quant_weight_ii=quant_weight_ii.value,
quant_weight_if=quant_weight_if.value,
quant_weight_ic=quant_weight_ic.value,
quant_weight_io=quant_weight_io.value,
quant_weight_hi=quant_weight_hi.value,
quant_weight_hf=quant_weight_hf.value,
quant_weight_hc=quant_weight_hc.value,
quant_weight_ho=quant_weight_ho.value,
quant_input_value,
_get_dequantize_tensor(quant_hidden_state),
_get_dequantize_tensor(quant_cell_state),
quant_weight_ii=_get_dequantize_tensor(quant_weight_ii),
quant_weight_if=_get_dequantize_tensor(quant_weight_if),
quant_weight_ic=_get_dequantize_tensor(quant_weight_ic),
quant_weight_io=_get_dequantize_tensor(quant_weight_io),
quant_weight_hi=_get_dequantize_tensor(quant_weight_hi),
quant_weight_hf=_get_dequantize_tensor(quant_weight_hf),
quant_weight_hc=_get_dequantize_tensor(quant_weight_hc),
quant_weight_ho=_get_dequantize_tensor(quant_weight_ho),
quant_bias_input=quant_bias_input,
quant_bias_forget=quant_bias_forget,
quant_bias_cell=quant_bias_cell,
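Note: _get_dequantize_tensor is imported from brevitas.quant_tensor and used throughout the recurrent forward passes above, but its implementation is not shown in this diff. Judging from the getattr(..., 'value', ...) fallbacks it replaces, it presumably unwraps a QuantTensor to its dequantized value and passes plain tensors through; a hypothetical sketch under that assumption:

    from brevitas.quant_tensor import QuantTensor

    def get_dequantize_tensor(t):
        # Assumed behaviour only: unwrap QuantTensor instances to their
        # dequantized value, leave ordinary tensors (and None) untouched.
        return t.value if isinstance(t, QuantTensor) else t
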