From fd258268f3dd86b249ba10fbeb6a35578cd36af8 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 11 Sep 2023 14:11:55 +0100 Subject: [PATCH] Feat (quant_tensor): refactor to store int values --- ...1_quant_tensor_quant_conv2d_overview.ipynb | 540 ++++++++---------- notebooks/02_quant_activation_overview.ipynb | 200 +++---- notebooks/03_anatomy_of_a_quantizer.ipynb | 494 ++++++++-------- notebooks/Brevitas_TVMCon2021.ipynb | 426 ++++++++------ src/brevitas/core/quant/binary.py | 4 +- src/brevitas/core/quant/int_base.py | 7 +- src/brevitas/core/stats/stats_op.py | 2 +- src/brevitas/graph/calibrate.py | 2 +- src/brevitas/nn/mixin/base.py | 30 +- src/brevitas/nn/quant_layer.py | 74 ++- src/brevitas/nn/quant_mha.py | 6 +- src/brevitas/nn/quant_rnn.py | 65 ++- src/brevitas/nn/utils.py | 4 +- src/brevitas/proxy/parameter_quant.py | 4 +- src/brevitas/proxy/runtime_quant.py | 10 +- src/brevitas/quant_tensor/__init__.py | 223 +++++--- src/brevitas/quant_tensor/torch_handler.py | 2 +- .../ptq/ptq_evaluate.py | 9 +- tests/brevitas/graph/test_calibration.py | 3 +- 19 files changed, 1093 insertions(+), 1012 deletions(-) diff --git a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb index 2e9ef9179..a4191eaa6 100644 --- a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb +++ b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb @@ -22,7 +22,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/user/.local/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, @@ -39,14 +39,22 @@ " padding: Union[int, Tuple[int, int]] = 0,\n", " dilation: Union[int, Tuple[int, int]] = 1,\n", " groups: int = 1,\n", + " padding_mode: str = 'zeros',\n", " bias: bool = True,\n", - " padding_type: str = 'standard',\n", " weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,\n", " bias_quant: Optional[BiasQuantType] = None,\n", " input_quant: Optional[ActQuantType] = None,\n", " output_quant: Optional[ActQuantType] = None,\n", " return_quant_tensor: bool = False,\n", + " device: Optional[torch.device] = None,\n", + " dtype: Optional[torch.dtype] = None,\n", " **kwargs) -> None:\n", + " # avoid an init error in the super class by setting padding to 0\n", + " if padding_mode == 'zeros' and padding == 'same' and stride > 1:\n", + " padding = 0\n", + " is_same_padded_strided = True\n", + " else:\n", + " is_same_padded_strided = False\n", " Conv2d.__init__(\n", " self,\n", " in_channels=in_channels,\n", @@ -54,9 +62,12 @@ " kernel_size=kernel_size,\n", " stride=stride,\n", " padding=padding,\n", + " padding_mode=padding_mode,\n", " dilation=dilation,\n", " groups=groups,\n", - " bias=bias)\n", + " bias=bias,\n", + " device=device,\n", + " dtype=dtype)\n", " QuantWBIOL.__init__(\n", " self,\n", " weight_quant=weight_quant,\n", @@ -65,9 +76,7 @@ " output_quant=output_quant,\n", " return_quant_tensor=return_quant_tensor,\n", " **kwargs)\n", - " assert self.padding_mode == 'zeros'\n", - " assert not (padding_type == 'same' and padding != 0)\n", - " self.padding_type = padding_type\n", + " self.is_same_padded_strided = is_same_padded_strided\n", "\n", "```" ], @@ -149,20 +158,28 @@ "scrolled": true }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/torch/_tensor.py:1255: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525553989/work/c10/core/TensorImpl.h:1758.)\n", + " return super(Tensor, self).rename(names)\n" + ] + }, { "data": { "text/plain": [ - "tensor([[[[-0.2594, 0.5392, 0.5916],\n", - " [ 0.3493, 0.6813, 0.2499],\n", - " [ 1.3732, 0.1229, -0.0084]],\n", + "tensor([[[[-0.1833, -0.6555, -0.2173],\n", + " [ 0.9269, 0.4396, 0.9834],\n", + " [-1.1737, 0.1799, -0.0602]],\n", "\n", - " [[ 0.0031, -0.1702, 0.1069],\n", - " [-0.8181, -0.8056, 0.0385],\n", - " [-0.4738, 0.0589, 0.1278]],\n", + " [[ 0.8093, -0.4485, 0.1675],\n", + " [ 0.4063, -0.5705, -0.5371],\n", + " [-0.4814, -0.3083, -0.0494]],\n", "\n", - " [[-0.1718, -0.1162, -0.1526],\n", - " [-0.9903, -0.3541, 0.1645],\n", - " [ 0.0557, -0.4458, -0.2080]]]], grad_fn=)" + " [[ 1.7889, 0.0758, 0.4831],\n", + " [ 0.5868, -0.3806, 0.2652],\n", + " [-0.6864, 1.2178, 0.3697]]]], grad_fn=)" ] }, "execution_count": 4, @@ -234,31 +251,31 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.0790, 0.0503, -0.0934],\n", - " [-0.1149, -0.1903, -0.1329],\n", - " [-0.1813, 0.0108, 0.0593]],\n", + "QuantTensor(int_value=tensor([[[[ -3., 40., 92.],\n", + " [ -92., -52., -99.],\n", + " [ 60., 121., 8.]],\n", "\n", - " [[ 0.0970, -0.0215, -0.0144],\n", - " [ 0.2280, 0.1239, -0.0090],\n", - " [ 0.1957, -0.2011, -0.0108]]],\n", + " [[ 2., 90., 123.],\n", + " [ -96., 124., -81.],\n", + " [ -1., 40., 12.]]],\n", "\n", "\n", - " [[[-0.0018, -0.1957, 0.1993],\n", - " [-0.0359, 0.1778, -0.1400],\n", - " [ 0.0916, 0.1059, 0.2173]],\n", + " [[[-125., 87., 6.],\n", + " [ 20., 27., -121.],\n", + " [-104., 44., -115.]],\n", "\n", - " [[-0.1670, 0.1939, -0.2191],\n", - " [-0.0215, 0.1688, -0.1383],\n", - " [-0.0449, -0.1185, 0.1742]]],\n", + " [[ -46., 24., -111.],\n", + " [ -60., -58., 75.],\n", + " [ 64., -7., 116.]]],\n", "\n", "\n", - " [[[-0.0808, -0.1652, -0.0233],\n", - " [-0.0700, 0.0467, -0.0485],\n", - " [ 0.1059, 0.1418, 0.1077]],\n", + " [[[-110., 22., 17.],\n", + " [ 111., 121., 65.],\n", + " [-127., 60., -118.]],\n", "\n", - " [[-0.0593, 0.0108, 0.0036],\n", - " [-0.1508, 0.0808, 0.1616],\n", - " [ 0.0144, -0.0287, -0.1365]]]], grad_fn=), scale=tensor(0.0018, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[-118., 76., 108.],\n", + " [ -87., 81., 119.],\n", + " [ 69., -95., 39.]]]], grad_fn=), scale=tensor(0.0018, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 6, @@ -325,15 +342,15 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor(0.0173, grad_fn=)\n", - "tensor(0.0307, grad_fn=)\n" + "tensor(0.0211, grad_fn=)\n", + "tensor(0.0162, grad_fn=)\n" ] } ], @@ -361,34 +378,33 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.9489, -0.9111, -0.0536, 0.5788, 0.3645],\n", - " [ 0.3401, 1.4325, 0.6498, 0.6411, -1.4390],\n", - " [-1.9029, 0.7012, 0.1591, 1.9235, 0.5883],\n", - " [-2.7258, 2.5330, 0.9165, -0.0820, 3.4148],\n", - " [-0.3651, 1.0164, 0.9567, -0.2758, -1.1376]],\n", - "\n", - " [[-0.2414, 2.2111, -1.9124, -2.3814, -0.8805],\n", - " [ 1.3191, -0.8965, -0.2048, -3.8113, 1.1142],\n", - " [-0.3381, -0.2238, 1.2661, 0.0068, 0.2567],\n", - " [ 0.0731, -0.4280, 0.0909, 0.0875, -1.6851],\n", - " [-0.7744, -1.4127, -0.8143, 1.3557, -0.2802]]]],\n", - " grad_fn=), scale=tensor(0.0240, grad_fn=), zero_point=tensor(0.), bit_width=tensor(9.), signed_t=tensor(True), training_t=tensor(True))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "QuantTensor(int_value=None, scale=tensor(0.0187, grad_fn=), zero_point=tensor(0.), bit_width=tensor(9.), signed_t=tensor(True), training_t=tensor(True))\n", + "tensor([[[[-0.1106, 1.1945, -0.4972, -2.0968, 0.7175],\n", + " [-2.5901, 0.0588, -0.2014, 2.1486, 1.6435],\n", + " [ 0.9067, -2.5212, 2.2193, 0.2352, -0.8395],\n", + " [-0.8351, 0.6341, -0.5551, 0.1040, -3.3151],\n", + " [-0.8979, -0.7092, 3.8232, 1.0875, 0.3954]],\n", + "\n", + " [[ 1.4363, -1.3973, 1.3249, 2.6914, 0.3660],\n", + " [ 1.5057, 1.8094, 0.5100, -1.6874, 1.9981],\n", + " [ 1.2472, -1.7813, 0.0334, -1.2880, -2.9333],\n", + " [ 0.0180, -1.4298, -2.9978, 0.5494, -1.4548],\n", + " [ 1.6738, -0.3177, -0.3721, -0.1650, -1.1871]]]],\n", + " grad_fn=)\n" + ] } ], "source": [ "out_tensor = out_tensor_0 + out_tensor_1\n", - "out_tensor" + "print(out_tensor)\n", + "print(out_tensor.value)" ] }, { @@ -401,7 +417,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -417,23 +433,23 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[1.5800, 1.0157],\n", - " [1.4445, 0.8577]],\n", + "QuantTensor(int_value=tensor([[[[ 30.0000, 37.0000],\n", + " [124.0000, 34.0000]],\n", "\n", - " [[0.5643, 1.2414],\n", - " [1.0383, 0.9028]],\n", + " [[118.0000, 34.0000],\n", + " [ 73.0000, 23.0000]],\n", "\n", - " [[0.5191, 0.6546],\n", - " [2.1442, 0.5868]]]], grad_fn=), scale=tensor(0.0226, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 46.0000, 30.0000],\n", + " [ 47.0000, 78.0000]]]], grad_fn=), scale=tensor(0.0173, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 108, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -455,29 +471,37 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 13, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/ipykernel_launcher.py:1: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525553989/work/torch/csrc/utils/python_arg_parser.cpp:350.)\n", + " \"\"\"Entry point for launching an IPython kernel.\n" + ] + }, { "data": { "text/plain": [ - "tensor([[[[-0.4943, -0.9938, -0.9073, 0.7681],\n", - " [-0.3262, 0.9186, 0.1786, 0.3659],\n", - " [ 0.7489, 0.8946, -0.0451, -0.5594],\n", - " [-0.1346, -0.4943, -0.4770, 0.6951]],\n", + "tensor([[[[ 0.4770, 0.2212, 0.0691, 0.5650],\n", + " [-0.0346, -0.6618, -0.4635, -0.3482],\n", + " [ 0.9730, -0.7245, -0.5881, -0.5287],\n", + " [-0.0863, 0.8857, 0.5287, -0.4498]],\n", "\n", - " [[ 0.0676, 0.5111, 0.4943, 0.8459],\n", - " [-0.8990, -0.9426, 0.0676, -0.7945],\n", - " [-0.9220, 0.0676, -0.5594, 0.6321],\n", - " [-0.0676, 0.7772, 0.7177, -0.4414]],\n", + " [[ 0.9669, 0.5650, -0.6211, -0.4498],\n", + " [-0.2376, 0.6103, 0.5287, 0.2700],\n", + " [-0.6808, 0.8519, 0.2700, -0.5531],\n", + " [-0.0173, 0.8264, 0.3782, -0.1881]],\n", "\n", - " [[ 0.4770, 0.2220, 0.0676, 0.5747],\n", - " [-0.0451, -0.6710, -0.4594, -0.3462],\n", - " [ 0.9729, -0.7177, -0.5896, -0.5276],\n", - " [-0.0900, 0.8852, 0.5276, -0.4414]]]], grad_fn=)" + " [[-0.6211, -0.9764, -0.5993, 0.4770],\n", + " [ 0.5033, 0.6618, -0.1881, -0.6211],\n", + " [-0.8031, 0.1375, 0.5287, 0.8740],\n", + " [-0.6714, 0.6714, -0.5650, 0.8611]]]], grad_fn=)" ] }, - "execution_count": 109, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -497,26 +521,26 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.9693, -0.9431, 0.2459],\n", - " [ 0.5416, 0.9037, -0.5278],\n", - " [-0.6207, -1.3578, -0.4815]],\n", + "QuantTensor(int_value=tensor([[[[-11164., -5891., 11231.],\n", + " [-13986., 3252., -12343.],\n", + " [-13112., 11651., 26235.]],\n", "\n", - " [[ 0.4551, -1.4065, 0.8889],\n", - " [-0.3393, 0.0803, -0.1748],\n", - " [-0.0977, 0.6284, -0.7193]],\n", + " [[ -1595., 17255., -8607.],\n", + " [-17736., 18224., 7286.],\n", + " [ 4118., -7880., 32600.]],\n", "\n", - " [[ 0.3655, 0.7626, -0.2634],\n", - " [-0.3453, 0.3349, 0.1923],\n", - " [ 0.5993, -0.9579, 0.3557]]]], grad_fn=), scale=tensor([[[[3.2208e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 8555., 22742., -779.],\n", + " [ -5418., 16261., 34914.],\n", + " [ 11799., -11921., 6282.]]]], grad_fn=), scale=tensor([[[[3.1958e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 110, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -533,20 +557,9 @@ }, { "cell_type": "code", - "execution_count": 111, + "execution_count": 15, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 111, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -569,26 +582,26 @@ }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 5.7000e-03, 2.5000e-03, -1.2400e-02, -7.2000e-03, 3.7000e-03],\n", - " [-2.3000e-03, 7.0000e-04, -1.2700e-02, 5.2000e-03, 4.0000e-04],\n", - " [-7.9000e-03, 9.5000e-03, 6.6000e-03, 5.4000e-03, 2.5000e-03],\n", - " [ 1.1100e-02, 2.4000e-03, 1.0000e-02, -3.7000e-03, 7.2000e-03],\n", - " [-1.1500e-02, -5.8000e-03, -9.3000e-03, 1.0000e-02, 3.5000e-03]],\n", + "QuantTensor(int_value=tensor([[[[ 72, -37, 77, -24, -89],\n", + " [-120, -81, 72, -113, -97],\n", + " [ -10, 101, 38, -119, 69],\n", + " [ 83, 1, -69, 39, -54],\n", + " [ 113, -60, 97, 0, 109]],\n", "\n", - " [[-6.8000e-03, 1.1500e-02, -1.0600e-02, -1.5000e-03, -1.9000e-03],\n", - " [ 2.9000e-03, 9.5000e-03, 7.2000e-03, -3.7000e-03, 7.7000e-03],\n", - " [-2.4000e-03, -8.9000e-03, -1.2000e-02, -8.1000e-03, 7.2000e-03],\n", - " [-1.1300e-02, -9.7000e-03, -1.0000e-03, 1.0100e-02, 3.8000e-03],\n", - " [-1.1900e-02, 6.9000e-03, 8.3000e-03, 1.0000e-04, -6.9000e-03]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[-109, 114, -64, 92, 71],\n", + " [ -6, 92, -85, 50, 65],\n", + " [ -83, -12, 74, 92, -6],\n", + " [ -21, 95, 3, -29, -65],\n", + " [-118, -48, 54, -25, 9]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 112, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -600,9 +613,8 @@ "bit_width = 8\n", "zero_point = 0.\n", "int_value = torch.randint(low=- 2 ** (bit_width - 1), high=2 ** (bit_width - 1) - 1, size=(1, 2, 5, 5))\n", - "quant_value = (int_value - zero_point) * scale\n", "quant_tensor_input = QuantTensor(\n", - " quant_value, \n", + " int_value, \n", " scale=torch.tensor(scale), \n", " zero_point=torch.tensor(zero_point), \n", " bit_width=torch.tensor(float(bit_width)),\n", @@ -613,20 +625,9 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 17, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 113, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert quant_tensor_input.is_valid" ] @@ -642,26 +643,26 @@ }, { "cell_type": "code", - "execution_count": 114, + "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.0085, 0.0066, 0.0050],\n", - " [-0.0038, -0.0009, -0.0115],\n", - " [-0.0055, -0.0037, 0.0009]],\n", + "QuantTensor(int_value=tensor([[[[-1., 3., -1.],\n", + " [-1., 3., -4.],\n", + " [-1., -2., -2.]],\n", "\n", - " [[ 0.0015, -0.0027, -0.0079],\n", - " [-0.0034, -0.0060, 0.0043],\n", - " [-0.0008, 0.0052, -0.0033]],\n", + " [[-2., 2., 6.],\n", + " [ 4., 3., 1.],\n", + " [ 1., -0., 1.]],\n", "\n", - " [[-0.0015, 0.0082, -0.0038],\n", - " [-0.0021, 0.0004, -0.0054],\n", - " [-0.0021, -0.0079, 0.0013]]]], grad_fn=), scale=tensor([[[[1.8448e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [[-2., -1., 0.],\n", + " [-2., 1., 1.],\n", + " [-1., 2., -2.]]]], grad_fn=), scale=tensor([[[[1.8307e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 114, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -675,20 +676,9 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": 19, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 115, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -702,26 +692,26 @@ }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.0035, -0.0037, -0.0050],\n", - " [ 0.0010, -0.0051, -0.0027],\n", - " [-0.0010, 0.0047, 0.0017]],\n", + "QuantTensor(int_value=tensor([[[[-41695., 22899., -6250.],\n", + " [-18837., 44913., -16334.],\n", + " [ 2686., -14154., -4355.]],\n", "\n", - " [[ 0.0021, 0.0002, 0.0027],\n", - " [ 0.0028, 0.0002, -0.0044],\n", - " [ 0.0008, -0.0052, -0.0024]],\n", + " [[ 12052., -11925., 19953.],\n", + " [ 6630., -9381., -13156.],\n", + " [ -5705., -8407., 22889.]],\n", "\n", - " [[ 0.0010, -0.0052, -0.0011],\n", - " [-0.0018, 0.0024, 0.0011],\n", - " [-0.0001, 0.0039, 0.0035]]]], grad_fn=), scale=tensor([[[[1.7410e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [[ -5578., 27082., 14164.],\n", + " [ -8009., 12023., -22616.],\n", + " [ 20760., -1695., 15229.]]]], grad_fn=), scale=tensor([[[[1.7393e-11]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 116, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -741,26 +731,26 @@ }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.2111, 0.4060, 0.3654],\n", - " [-0.7876, 0.8119, -0.9825],\n", - " [-0.5115, 0.3979, -0.3248]],\n", + "QuantTensor(int_value=tensor([[[[ -22., -50., 4.],\n", + " [ -53., -26., -23.],\n", + " [ -60., 2., -57.]],\n", "\n", - " [[ 0.3816, 0.0568, -0.0812],\n", - " [ 1.0312, -0.7876, 0.8038],\n", - " [-0.3491, -0.4141, 0.0650]],\n", + " [[ 14., 85., -128.],\n", + " [ -63., 46., -40.],\n", + " [ 18., -53., -13.]],\n", "\n", - " [[-0.5846, -0.4222, -0.0731],\n", - " [-0.7389, 0.5034, -0.2517],\n", - " [-0.1624, -0.4385, 0.7308]]]], grad_fn=), scale=tensor(0.0081, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 113., 25., -30.],\n", + " [ -45., -51., -21.],\n", + " [ 67., 50., 14.]]]], grad_fn=), scale=tensor(0.0096, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 117, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -777,20 +767,9 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 22, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 118, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -816,7 +795,7 @@ }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 24, "metadata": { "tags": [ "raises-exception" @@ -830,12 +809,12 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_48365/2280634207.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0min_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkernel_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m bias_quant=Int8Bias, return_quant_tensor=True)\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mbias_quant_conv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/opt/conda/envs/torch_1.10/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/nn/quant_conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 192\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 193\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minner_forward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_bias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/nn/quant_layer.py\u001b[0m in \u001b[0;36mforward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 330\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 332\u001b[0;31m \u001b[0mquant_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias_quant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_scale\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_bit_width\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 333\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_inference_quant_bias\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cached_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CachedIO\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_bias\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/envs/torch_1.10/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/proxy/parameter_quant.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mimpl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_scale\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_scale\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input scale required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_bit_width\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_bit_width\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input bit-width required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/tmp/ipykernel_589993/2280634207.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0min_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkernel_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m bias_quant=Int8Bias, return_quant_tensor=True)\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mbias_quant_conv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1195\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 198\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minner_forward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_bias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py\u001b[0m in \u001b[0;36mforward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 337\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 339\u001b[0;31m \u001b[0mquant_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias_quant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_scale\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_bit_width\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 340\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_inference_quant_bias\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 341\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cached_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CachedIO\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_bias\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1195\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/parameter_quant.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0mimpl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_scale\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_scale\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 206\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input scale required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 207\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_bit_width\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_bit_width\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input bit-width required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mRuntimeError\u001b[0m: Input scale required" ] } @@ -858,26 +837,26 @@ }, { "cell_type": "code", - "execution_count": 120, + "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.0005, 0.0043, -0.0004],\n", - " [ 0.0005, 0.0106, 0.0012],\n", - " [ 0.0021, 0.0007, -0.0050]],\n", + "QuantTensor(int_value=tensor([[[[-131., -126., -126.],\n", + " [-129., -128., -126.],\n", + " [-131., -126., -128.]],\n", "\n", - " [[-0.0067, -0.0035, -0.0059],\n", - " [-0.0050, -0.0015, -0.0039],\n", - " [ 0.0015, 0.0028, -0.0008]],\n", + " [[ 128., 127., 126.],\n", + " [ 128., 128., 126.],\n", + " [ 126., 126., 127.]],\n", "\n", - " [[-0.0051, -0.0050, 0.0060],\n", - " [-0.0015, 0.0037, 0.0071],\n", - " [ 0.0067, 0.0035, -0.0071]]]], grad_fn=), scale=tensor([[[[1.8108e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" + " [[-131., -125., -128.],\n", + " [-126., -126., -130.],\n", + " [-127., -131., -130.]]]], grad_fn=), scale=tensor([[[[1.8528e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 120, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -895,26 +874,26 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.3825, 0.1371, 0.9135],\n", - " [-0.2016, 0.7495, -0.4071],\n", - " [-0.0755, 0.5283, 0.2388]],\n", + "QuantTensor(int_value=tensor([[[[-14802., 5940., -11690.],\n", + " [ 1056., -180., -40225.],\n", + " [-10879., 4592., -15348.]],\n", "\n", - " [[ 0.0788, -0.3802, -0.2234],\n", - " [ 0.8678, -0.5546, 0.4408],\n", - " [-0.6788, 0.4422, 0.3007]],\n", + " [[ 34887., 24540., -13335.],\n", + " [ 16723., -21359., 5380.],\n", + " [ -5614., -7566., 4970.]],\n", "\n", - " [[ 0.4412, -0.3205, 1.0033],\n", - " [-0.0083, -0.3295, -0.2076],\n", - " [ 0.4417, -0.1046, -0.3493]]]], grad_fn=), scale=tensor([[[[3.8610e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" + " [[-50260., 31346., 21783.],\n", + " [ 22957., -6246., -582.],\n", + " [ 22654., 25542., -16811.]]]], grad_fn=), scale=tensor([[[[2.9050e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 121, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -928,26 +907,26 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.0036, 0.0024, -0.0033],\n", - " [ 0.0050, 0.0080, -0.0014],\n", - " [-0.0036, -0.0080, -0.0029]],\n", + "QuantTensor(int_value=tensor([[[[ -8843., -19959., 1488.],\n", + " [-31183., 27010., 31805.],\n", + " [ 24630., 31302., -28504.]],\n", "\n", - " [[ 0.0083, -0.0093, 0.0048],\n", - " [ 0.0035, 0.0015, -0.0011],\n", - " [-0.0003, 0.0067, 0.0013]],\n", + " [[ -2123., 7453., -10084.],\n", + " [ 31452., -42144., 13385.],\n", + " [-30325., 5258., 18670.]],\n", "\n", - " [[-0.0009, -0.0019, 0.0039],\n", - " [ 0.0010, 0.0056, -0.0037],\n", - " [ 0.0091, -0.0095, 0.0054]]]], grad_fn=), scale=tensor([[[[1.8384e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 8588., -1254., -39190.],\n", + " [ 8470., -22817., -26758.],\n", + " [-19270., -5130., 45632.]]]], grad_fn=), scale=tensor([[[[1.7377e-11]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 122, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -967,7 +946,7 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 28, "metadata": { "tags": [ "raises-exception" @@ -981,12 +960,12 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_48365/2990591641.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0min_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkernel_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m output_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0moutput_bias_quant_conv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/opt/conda/envs/torch_1.10/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/nn/quant_conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 192\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 193\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minner_forward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_bias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/nn/quant_layer.py\u001b[0m in \u001b[0;36mforward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 330\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 332\u001b[0;31m \u001b[0mquant_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias_quant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_scale\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_bit_width\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 333\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_inference_quant_bias\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cached_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CachedIO\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_bias\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/envs/torch_1.10/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/proxy/parameter_quant.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mimpl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_scale\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_scale\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input scale required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_bit_width\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_bit_width\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input bit-width required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/tmp/ipykernel_589993/2990591641.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0min_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkernel_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m output_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0moutput_bias_quant_conv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1195\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 198\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minner_forward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_bias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py\u001b[0m in \u001b[0;36mforward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 337\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 339\u001b[0;31m \u001b[0mquant_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias_quant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_scale\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_bit_width\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 340\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_inference_quant_bias\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 341\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cached_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CachedIO\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_bias\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1195\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/parameter_quant.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0mimpl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_scale\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_scale\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 206\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input scale required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 207\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_bit_width\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_bit_width\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input bit-width required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mRuntimeError\u001b[0m: Input scale required" ] } @@ -1007,26 +986,26 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[[ 0.2152, 0.8346, 0.0746],\n", - " [-0.0738, -0.5212, 0.1019],\n", - " [-0.6004, 0.1500, -0.1453]],\n", + "tensor([[[[-0.6938, 0.0069, 0.1652],\n", + " [-0.4801, -0.8120, 0.5233],\n", + " [ 0.4159, 0.4662, 0.2565]],\n", "\n", - " [[-1.1551, -1.3458, -0.1312],\n", - " [ 0.2502, -0.5267, 0.2412],\n", - " [-0.3556, -0.3289, -0.2276]],\n", + " [[ 0.3206, -0.5500, -0.5254],\n", + " [ 0.1864, 1.0210, -0.3706],\n", + " [-0.1159, 0.6967, -0.0437]],\n", "\n", - " [[-0.4599, -0.6094, 0.4682],\n", - " [-0.5064, -0.6768, -0.6638],\n", - " [ 0.0066, -0.3581, 0.2359]]]], grad_fn=)" + " [[-0.6209, -0.5257, -0.6592],\n", + " [ 0.6389, 0.2658, 0.4542],\n", + " [-0.3761, -0.7776, -0.2897]]]], grad_fn=)" ] }, - "execution_count": 124, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1051,30 +1030,30 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.6879, -0.6632, -0.2411],\n", - " [ 0.2064, -0.7371, 0.3910],\n", - " [ 0.9533, 0.2994, 0.6546]],\n", + "QuantTensor(int_value=tensor([[[[ -8448., 34284., 23327.],\n", + " [ 3449., 3914., -2553.],\n", + " [-20566., 22916., 20011.]],\n", "\n", - " [[-0.4684, -0.4495, -0.5021],\n", - " [ 0.5738, 0.4199, -0.3380],\n", - " [ 0.6218, -0.0408, -0.8483]],\n", + " [[ 15277., -19236., -7557.],\n", + " [ 4996., -14594., -22873.],\n", + " [ 17500., 2874., 22446.]],\n", "\n", - " [[-0.5625, 0.1837, -1.0575],\n", - " [-1.2816, -0.4993, -0.3409],\n", - " [ 0.4556, -1.4269, 0.5369]]]], grad_fn=), scale=tensor([[[[3.0975e-05]]]], grad_fn=), zero_point=tensor([[[[ 1276.0774]],\n", + " [[ 10549., -14056., 34549.],\n", + " [ -5044., 22675., 7862.],\n", + " [-21969., 13473., -1175.]]]], grad_fn=), scale=tensor([[[[2.7130e-05]]]], grad_fn=), zero_point=tensor([[[[ 6313.4204]],\n", "\n", - " [[-3152.4585]],\n", + " [[-2667.2593]],\n", "\n", - " [[ 7320.2324]]]], grad_fn=), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [[-5507.9629]]]], grad_fn=), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 125, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -1089,20 +1068,9 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 31, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 126, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -1116,26 +1084,26 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[[ 0.8357, 0.0733, 0.9527],\n", - " [ 0.1803, 0.2154, 0.7598],\n", - " [ 1.1121, -0.8728, 1.0039]],\n", + "tensor([[[[ 0.0650, 0.2496, -1.2857],\n", + " [ 1.0231, 0.0516, 0.7592],\n", + " [ 0.5882, -0.7619, 0.7604]],\n", "\n", - " [[ 0.7917, 1.0063, 0.6516],\n", - " [-0.1852, -0.7263, 0.0956],\n", - " [-0.1876, 0.2747, -0.1617]],\n", + " [[-0.6307, 0.1476, 1.0949],\n", + " [-0.1488, 0.0472, 0.0097],\n", + " [-0.2861, 0.0266, -0.2970]],\n", "\n", - " [[ 0.8299, 0.9934, -0.3821],\n", - " [ 0.4865, 0.9309, -0.7924],\n", - " [-0.4201, 0.2343, 0.1532]]]], grad_fn=)" + " [[ 0.0580, 1.2994, 0.3841],\n", + " [ 0.2056, 0.0496, -0.7915],\n", + " [ 0.4698, -0.8724, -0.0405]]]], grad_fn=)" ] }, - "execution_count": 127, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -1171,7 +1139,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.15" + "version": "3.7.16" }, "vscode": { "interpreter": { diff --git a/notebooks/02_quant_activation_overview.ipynb b/notebooks/02_quant_activation_overview.ipynb index 4d2ac73d1..e8afe3b64 100644 --- a/notebooks/02_quant_activation_overview.ipynb +++ b/notebooks/02_quant_activation_overview.ipynb @@ -26,14 +26,14 @@ }, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/torch/_tensor.py:1255: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525553989/work/c10/core/TensorImpl.h:1758.)\n", + " return super(Tensor, self).rename(names)\n" + ] } ], "source": [ @@ -68,18 +68,7 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "torch.manual_seed(0)\n", "input_output_quant_conv = QuantConv2d(\n", @@ -149,18 +138,17 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.4566, -0.5707, -0.5517, 0.5897, 1.5409],\n", - " [ 0.5136, -0.5897, -0.5707, 0.1902, -0.0761],\n", - " [-0.4946, -1.5029, -0.1902, 0.4376, 1.3317],\n", - " [-1.6361, 2.0736, 1.7122, 2.3780, -1.1224],\n", - " [-0.3234, -1.0844, -0.0761, -0.0951, -0.7610]],\n", + "QuantTensor(int_value=tensor([[[[ -24., -30., -29., 31., 81.],\n", + " [ 27., -31., -30., 10., -4.],\n", + " [ -26., -79., -10., 23., 70.],\n", + " [ -86., 109., 90., 125., -59.],\n", + " [ -17., -57., -4., -5., -40.]],\n", "\n", - " [[-1.5980, 0.0190, -0.7419, 0.1902, 0.6278],\n", - " [ 0.6468, -0.2473, -0.5327, 1.1605, 0.4376],\n", - " [-0.7990, -1.2936, -0.7419, -1.3127, -0.2283],\n", - " [-2.4351, -0.0761, 0.2283, 0.7990, -0.1902],\n", - " [-0.3615, -1.2175, -0.6278, -0.4566, 1.9214]]]],\n", - " grad_fn=), scale=tensor(0.0190, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[ -84., 1., -39., 10., 33.],\n", + " [ 34., -13., -28., 61., 23.],\n", + " [ -42., -68., -39., -69., -12.],\n", + " [-128., -4., 12., 42., -10.],\n", + " [ -19., -64., -33., -24., 101.]]]], grad_fn=), scale=tensor(0.0190, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 4, @@ -178,18 +166,7 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -241,18 +218,17 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.4566, -0.5707, -0.5517, 0.5897, 1.5409],\n", - " [ 0.5136, -0.5897, -0.5707, 0.1902, -0.0761],\n", - " [-0.4946, -1.5029, -0.1902, 0.4376, 1.3317],\n", - " [-1.6361, 2.0736, 1.7122, 2.3780, -1.1224],\n", - " [-0.3234, -1.0844, -0.0761, -0.0951, -0.7610]],\n", + "QuantTensor(int_value=tensor([[[[ -24., -30., -29., 31., 81.],\n", + " [ 27., -31., -30., 10., -4.],\n", + " [ -26., -79., -10., 23., 70.],\n", + " [ -86., 109., 90., 125., -59.],\n", + " [ -17., -57., -4., -5., -40.]],\n", "\n", - " [[-1.5980, 0.0190, -0.7419, 0.1902, 0.6278],\n", - " [ 0.6468, -0.2473, -0.5327, 1.1605, 0.4376],\n", - " [-0.7990, -1.2936, -0.7419, -1.3127, -0.2283],\n", - " [-2.4351, -0.0761, 0.2283, 0.7990, -0.1902],\n", - " [-0.3615, -1.2175, -0.6278, -0.4566, 1.9214]]]],\n", - " grad_fn=), scale=tensor(0.0190, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[ -84., 1., -39., 10., 33.],\n", + " [ 34., -13., -28., 61., 23.],\n", + " [ -42., -68., -39., -69., -12.],\n", + " [-128., -4., 12., 42., -10.],\n", + " [ -19., -64., -33., -24., 101.]]]], grad_fn=), scale=tensor(0.0190, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 7, @@ -281,17 +257,17 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[0.0000, 0.0000, 0.0000, 0.5974, 1.5402],\n", - " [0.5041, 0.0000, 0.0000, 0.1867, 0.0000],\n", - " [0.0000, 0.0000, 0.0000, 0.4481, 1.3255],\n", - " [0.0000, 2.0817, 1.7083, 2.3804, 0.0000],\n", - " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n", + "QuantTensor(int_value=tensor([[[[ 0., 0., 0., 64., 165.],\n", + " [ 54., 0., 0., 20., 0.],\n", + " [ 0., 0., 0., 48., 142.],\n", + " [ 0., 223., 183., 255., 0.],\n", + " [ 0., 0., 0., 0., 0.]],\n", "\n", - " [[0.0000, 0.0187, 0.0000, 0.1867, 0.6254],\n", - " [0.6348, 0.0000, 0.0000, 1.1668, 0.4387],\n", - " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", - " [0.0000, 0.0000, 0.2334, 0.7935, 0.0000],\n", - " [0.0000, 0.0000, 0.0000, 0.0000, 1.9230]]]], grad_fn=), scale=tensor(0.0093, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))" + " [[ 0., 2., 0., 20., 67.],\n", + " [ 68., 0., 0., 125., 47.],\n", + " [ 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 25., 85., 0.],\n", + " [ 0., 0., 0., 0., 206.]]]], grad_fn=), scale=tensor(0.0093, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))" ] }, "execution_count": 8, @@ -342,7 +318,7 @@ { "data": { "text/plain": [ - "QuantTensor(value=(tensor([[[[0.3878, 0.3611, 0.3655, 0.6433, 0.8236],\n", + "tensor([[[[0.3878, 0.3611, 0.3655, 0.6433, 0.8236],\n", " [0.6257, 0.3567, 0.3611, 0.5474, 0.4810],\n", " [0.3788, 0.1820, 0.4526, 0.6077, 0.7911],\n", " [0.1630, 0.8883, 0.8471, 0.9151, 0.2456],\n", @@ -353,7 +329,7 @@ " [0.3102, 0.2152, 0.3226, 0.2120, 0.4432],\n", " [0.0805, 0.4810, 0.5568, 0.6898, 0.4526],\n", " [0.4106, 0.2284, 0.3480, 0.3878, 0.8723]]]],\n", - " grad_fn=), None, None, None), scale=None, zero_point=None, bit_width=None, signed_t=None, training_t=tensor(True))" + " grad_fn=)" ] }, "execution_count": 10, @@ -363,7 +339,6 @@ ], "source": [ "from brevitas.nn import QuantSigmoid\n", - "\n", "return_disabled_quant_sigmoid = QuantSigmoid(act_quant=None, return_quant_tensor=True)\n", "sigmoid_out_tensor = return_disabled_quant_sigmoid(out_tensor)\n", "sigmoid_out_tensor" @@ -373,20 +348,10 @@ "cell_type": "code", "execution_count": 11, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "assert not sigmoid_out_tensor.is_valid" + "from brevitas.quant_tensor import QuantTensor\n", + "assert not isinstance(sigmoid_out_tensor, QuantTensor)" ] }, { @@ -406,7 +371,7 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[0.0000, 0.0000, 0.0000, 0.5854, 1.5485],\n", + "QuantTensor(int_value=tensor([[[[0.0000, 0.0000, 0.0000, 0.5854, 1.5485],\n", " [0.5099, 0.0000, 0.0000, 0.1888, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.4532, 1.3219],\n", " [0.0000, 2.0772, 1.6996, 2.3794, 0.0000],\n", @@ -416,7 +381,7 @@ " [0.6421, 0.0000, 0.0000, 1.1708, 0.4343],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.2266, 0.7931, 0.0000],\n", - " [0.0000, 0.0000, 0.0000, 0.0000, 1.9262]]]], grad_fn=), scale=tensor(0.0189, grad_fn=), zero_point=tensor(129., grad_fn=), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))" + " [0.0000, 0.0000, 0.0000, 0.0000, 1.9262]]]], grad_fn=), scale=tensor(0.0189, grad_fn=), zero_point=tensor(129., grad_fn=), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))" ] }, "execution_count": 12, @@ -426,10 +391,10 @@ ], "source": [ "from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat\n", - "\n", "shifted_quant_identity = QuantIdentity(act_quant=ShiftedUint8ActPerTensorFloat, return_quant_tensor=True)\n", "return_disabled_quant_relu = QuantReLU(act_quant=None, return_quant_tensor=True)\n", - "return_disabled_quant_relu(shifted_quant_identity(inp))" + "out = shifted_quant_identity(inp)\n", + "return_disabled_quant_relu(out)" ] }, { @@ -555,18 +520,7 @@ "cell_type": "code", "execution_count": 16, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "out1_train = quant_identity(inp1)\n", "out2_train = quant_identity(inp2)\n", @@ -577,18 +531,7 @@ "cell_type": "code", "execution_count": 17, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "quant_identity.eval()\n", "out1_eval = quant_identity(inp1)\n", @@ -617,19 +560,19 @@ "evalue": "'Int8ActPerTensorFloatMinMaxInit' can not resolve attribute 'max_val' while building 'scaling_init_impl'", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mDependencyError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mbrevitas\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mQuantHardTanh\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0mQuantHardTanh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_activation.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_quant, input_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m 117\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 118\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 119\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 120\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 121\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_layer.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m 77\u001b[0m \u001b[0mpassthrough_act\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 79\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 80\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 81\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\act.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, act_quant, **kwargs)\u001b[0m\n\u001b[0;32m 157\u001b[0m \u001b[0mproxy_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'act_'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 158\u001b[0m \u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m''\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 159\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 160\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 161\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\base.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[0;32m 98\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mfilter_kwargs\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 100\u001b[1;33m \u001b[0mquant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mproxy_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 101\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mproxy_protocol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[0;32m 108\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 109\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 110\u001b[1;33m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mActQuantProxyFromInjector\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 111\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_passthrough_act\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_passthrough_act\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector, export_mode, export_handler)\u001b[0m\n\u001b[0;32m 74\u001b[0m \u001b[1;31m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 75\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 76\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_tracked_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 77\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_handler\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_mode\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36madd_tracked_module\u001b[1;34m(self, module)\u001b[0m\n\u001b[0;32m 130\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 131\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mupdate_tracked_modules\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 132\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 133\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 134\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Trying to add None as a parent module.\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36minit_tensor_quant\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 120\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 121\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 122\u001b[1;33m \u001b[0mtensor_quant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 123\u001b[0m \u001b[0mact_impl\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 124\u001b[0m \u001b[0mis_act_enabled\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_act_enabled\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtensor_quant\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - " \u001b[1;31m[... skipping hidden 1 frame]\u001b[0m\n", - "\u001b[1;31mDependencyError\u001b[0m: 'Int8ActPerTensorFloatMinMaxInit' can not resolve attribute 'max_val' while building 'scaling_init_impl'" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mDependencyError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_624727/3230415679.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mbrevitas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mQuantHardTanh\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mQuantHardTanh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_activation.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, act_quant, input_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mact_quant\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_quant_tensor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 85\u001b[0;31m **kwargs)\n\u001b[0m\u001b[1;32m 86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0mQuantLayerMixin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mQuantInputMixin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_quant\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 40\u001b[0;31m \u001b[0mQuantNonLinearActMixin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mact_impl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpassthrough_act\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/act.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, act_impl, passthrough_act, act_quant, act_proxy_prefix, act_kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0mnone_quant_injector\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mNoneActQuant\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mprefixed_kwargs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 126\u001b[0;31m **kwargs)\n\u001b[0m\u001b[1;32m 127\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/base.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquant\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mfilter_kwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs_prefix\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m \u001b[0mquant\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproxy_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 71\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproxy_protocol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/runtime_quant.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_layer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 89\u001b[0;31m \u001b[0mQuantProxyFromInjector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_layer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 90\u001b[0m \u001b[0mActQuantProxyProtocol\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_passthrough_act\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_is_passthrough_act\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[0;31m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 82\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_tracked_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 83\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisable_quant\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py\u001b[0m in \u001b[0;36madd_tracked_module\u001b[0;34m(self, module)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_tracked_modules\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 120\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_tensor_quant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 121\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Trying to add None as a parent module.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/runtime_quant.py\u001b[0m in \u001b[0;36minit_tensor_quant\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minit_tensor_quant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 102\u001b[0;31m \u001b[0mtensor_quant\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 103\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m'act_impl'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0mact_impl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/inject/__init__.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(cls, attrname)\u001b[0m\n\u001b[1;32m 127\u001b[0m message = \"{!r} can not resolve attribute {!r}\".format(\n\u001b[1;32m 128\u001b[0m cls.__name__, current_attr)\n\u001b[0;32m--> 129\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mDependencyError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 130\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0mmarker\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattribute\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhave_defaults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDependencyError\u001b[0m: 'Int8ActPerTensorFloatMinMaxInit' can not resolve attribute 'max_val' while building 'scaling_init_impl'" ] } ], @@ -666,18 +609,7 @@ "cell_type": "code", "execution_count": 20, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "out1_train = quant_hard_tanh(inp1)\n", "quant_hard_tanh.eval()\n", @@ -711,7 +643,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.16" } }, "nbformat": 4, diff --git a/notebooks/03_anatomy_of_a_quantizer.ipynb b/notebooks/03_anatomy_of_a_quantizer.ipynb index 2055a1714..b6dac057f 100644 --- a/notebooks/03_anatomy_of_a_quantizer.ipynb +++ b/notebooks/03_anatomy_of_a_quantizer.ipynb @@ -21,6 +21,14 @@ "execution_count": 1, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, { "data": { "text/plain": [ @@ -191,7 +199,7 @@ " @brevitas.jit.script_method\n", " def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:\n", " scale = self.scaling_impl(x)\n", - " y = binary_sign_ste(x) * scale\n", + " y = binary_sign_ste(x) #* scale\n", " y = self.delay_wrapper(x, y)\n", " return y, scale, self.zero_point(), self.bit_width()\n", "\n", @@ -247,10 +255,10 @@ { "data": { "text/plain": [ - "(tensor([[ 0.1000, 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000, -0.1000]], grad_fn=),\n", + "(tensor([[-1., -1., 1., 1.],\n", + " [-1., -1., 1., 1.],\n", + " [ 1., 1., -1., -1.],\n", + " [-1., -1., -1., -1.]]),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -292,10 +300,10 @@ { "data": { "text/plain": [ - "(tensor([[-0.1000, 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000, 0.1000],\n", - " [-0.1000, 0.1000, 0.1000, 0.1000]], grad_fn=),\n", + "(tensor([[ 1., 1., -1., -1.],\n", + " [-1., -1., -1., -1.],\n", + " [ 1., -1., -1., 1.],\n", + " [ 1., -1., 1., -1.]]),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -342,10 +350,10 @@ { "data": { "text/plain": [ - "(tensor([[ 1., -1., 1., 1.],\n", + "(tensor([[ 1., -1., 1., -1.],\n", " [ 1., 1., -1., 1.],\n", - " [ 1., 1., 1., -1.],\n", - " [-1., 1., -1., -1.]], grad_fn=),\n", + " [-1., -1., 1., 1.],\n", + " [-1., 1., 1., 1.]]),\n", " tensor(1., grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -379,10 +387,10 @@ { "data": { "text/plain": [ - "(tensor([[ 0.1000, -0.1000, -0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000, 0.1000],\n", - " [-0.1000, 0.1000, -0.1000, 0.1000]], grad_fn=),\n", + "(tensor([[ 1., -1., -1., 1.],\n", + " [ 1., 1., -1., 1.],\n", + " [ 1., -1., -1., 1.],\n", + " [ 1., -1., -1., -1.]]),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -448,30 +456,30 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, -0.1000]],\n", + "QuantTensor(int_value=tensor([[[[ 1., 1., 1.],\n", + " [-1., -1., 1.],\n", + " [ 1., -1., -1.]],\n", "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000],\n", - " [-0.1000, -0.1000, 0.1000]],\n", + " [[-1., 1., 1.],\n", + " [-1., 1., 1.],\n", + " [-1., -1., -1.]],\n", "\n", - " [[ 0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, -0.1000, 0.1000]]],\n", + " [[-1., 1., 1.],\n", + " [ 1., 1., -1.],\n", + " [-1., -1., 1.]]],\n", "\n", "\n", - " [[[ 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000]],\n", + " [[[ 1., 1., 1.],\n", + " [ 1., 1., 1.],\n", + " [ 1., -1., -1.]],\n", "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000]],\n", + " [[-1., 1., -1.],\n", + " [-1., 1., 1.],\n", + " [ 1., -1., -1.]],\n", "\n", - " [[ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, 0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=None, training_t=tensor(True))" + " [[ 1., -1., 1.],\n", + " [ 1., 1., 1.],\n", + " [ 1., -1., 1.]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=None, training_t=tensor(True))" ] }, "execution_count": 11, @@ -498,7 +506,19 @@ "cell_type": "code", "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "AssertionError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_625769/2912028296.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_valid\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], "source": [ "assert not quant_weight.is_valid" ] @@ -518,30 +538,30 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000]],\n", + "QuantTensor(int_value=tensor([[[[-1., 1., -1.],\n", + " [-1., 1., -1.],\n", + " [-1., 1., -1.]],\n", "\n", - " [[ 0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, -0.1000]],\n", + " [[-1., 1., 1.],\n", + " [-1., 1., 1.],\n", + " [-1., 1., 1.]],\n", "\n", - " [[-0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, -0.1000],\n", - " [-0.1000, -0.1000, -0.1000]]],\n", + " [[ 1., -1., -1.],\n", + " [-1., 1., 1.],\n", + " [-1., 1., -1.]]],\n", "\n", "\n", - " [[[ 0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, -0.1000, 0.1000]],\n", + " [[[ 1., 1., -1.],\n", + " [-1., 1., 1.],\n", + " [ 1., -1., -1.]],\n", "\n", - " [[-0.1000, -0.1000, 0.1000],\n", - " [-0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000]],\n", + " [[-1., 1., -1.],\n", + " [-1., -1., -1.],\n", + " [ 1., 1., 1.]],\n", "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[-1., 1., 1.],\n", + " [ 1., -1., -1.],\n", + " [-1., -1., 1.]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 13, @@ -560,11 +580,11 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "assert signed_quant_weight.is_valid == True" + "assert signed_quant_weight.is_valid" ] }, { @@ -578,39 +598,39 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1000, -0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000]],\n", + "QuantTensor(int_value=tensor([[[[ 1., 1., 1.],\n", + " [ 1., -1., 1.],\n", + " [ 1., 1., 1.]],\n", "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, -0.1000]],\n", + " [[ 1., 1., 1.],\n", + " [ 1., 1., 1.],\n", + " [-1., -1., 1.]],\n", "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [-0.1000, -0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000]]],\n", + " [[-1., -1., 1.],\n", + " [-1., 1., -1.],\n", + " [ 1., -1., -1.]]],\n", "\n", "\n", - " [[[-0.1000, -0.1000, -0.1000],\n", - " [-0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, -0.1000]],\n", + " [[[ 1., -1., -1.],\n", + " [ 1., -1., -1.],\n", + " [ 1., -1., -1.]],\n", "\n", - " [[-0.1000, -0.1000, 0.1000],\n", - " [-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000]],\n", + " [[-1., 1., 1.],\n", + " [-1., -1., 1.],\n", + " [-1., -1., -1.]],\n", "\n", - " [[ 0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 1., -1., 1.],\n", + " [ 1., -1., 1.],\n", + " [-1., -1., -1.]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 16, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -640,19 +660,27 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/torch/_tensor.py:1255: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525553989/work/c10/core/TensorImpl.h:1758.)\n", + " return super(Tensor, self).rename(names)\n" + ] + }, { "data": { "text/plain": [ - "QuantTensor(value=tensor([[-0.1000, 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, 0.1000, -0.1000]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + "QuantTensor(int_value=tensor([[-1., -1., 1., -1.],\n", + " [ 1., 1., 1., 1.],\n", + " [ 1., -1., -1., 1.],\n", + " [ 1., -1., 1., -1.]]), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 17, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -678,19 +706,19 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[ 0.0010, 0.0010, 0.0010, -0.0010],\n", - " [ 0.0010, -0.0010, 0.0010, -0.0010],\n", - " [-0.0010, -0.0010, -0.0010, -0.0010],\n", - " [ 0.0010, 0.0010, 0.0010, 0.0010]], grad_fn=), scale=tensor(0.0010, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + "QuantTensor(int_value=tensor([[-1., -1., 1., 1.],\n", + " [-1., -1., 1., 1.],\n", + " [-1., 1., -1., -1.],\n", + " [ 1., -1., 1., -1.]]), scale=tensor(0.0010, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 18, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -716,7 +744,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -740,7 +768,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": { "scrolled": true }, @@ -748,33 +776,33 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1876, -0.1876, -0.1876],\n", - " [ 0.1876, 0.1876, 0.1876],\n", - " [-0.1876, -0.1876, 0.1876]],\n", + "QuantTensor(int_value=tensor([[[[-1., 1., 1.],\n", + " [-1., -1., -1.],\n", + " [ 1., 1., 1.]],\n", "\n", - " [[-0.1876, -0.1876, 0.1876],\n", - " [ 0.1876, 0.1876, -0.1876],\n", - " [-0.1876, 0.1876, 0.1876]],\n", + " [[-1., -1., 1.],\n", + " [ 1., 1., 1.],\n", + " [ 1., -1., 1.]],\n", "\n", - " [[-0.1876, -0.1876, -0.1876],\n", - " [ 0.1876, 0.1876, 0.1876],\n", - " [-0.1876, 0.1876, -0.1876]]],\n", + " [[-1., 1., 1.],\n", + " [-1., 1., -1.],\n", + " [-1., 1., 1.]]],\n", "\n", "\n", - " [[[-0.1876, -0.1876, -0.1876],\n", - " [ 0.1876, 0.1876, -0.1876],\n", - " [ 0.1876, -0.1876, -0.1876]],\n", + " [[[ 1., -1., -1.],\n", + " [ 1., -1., -1.],\n", + " [-1., 1., 1.]],\n", "\n", - " [[-0.1876, 0.1876, -0.1876],\n", - " [ 0.1876, -0.1876, -0.1876],\n", - " [-0.1876, -0.1876, 0.1876]],\n", + " [[ 1., 1., 1.],\n", + " [-1., 1., 1.],\n", + " [-1., 1., 1.]],\n", "\n", - " [[-0.1876, 0.1876, 0.1876],\n", - " [ 0.1876, -0.1876, 0.1876],\n", - " [-0.1876, -0.1876, -0.1876]]]], grad_fn=), scale=tensor(0.1876, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[-1., -1., 1.],\n", + " [ 1., 1., -1.],\n", + " [ 1., 1., 1.]]]], grad_fn=), scale=tensor(0.1910, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 20, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -793,7 +821,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -802,7 +830,7 @@ "True" ] }, - "execution_count": 21, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -820,16 +848,16 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor(0.1897, grad_fn=)" + "tensor(0.1880, grad_fn=)" ] }, - "execution_count": 22, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -850,7 +878,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": { "tags": [ "raises-exception" @@ -862,11 +890,11 @@ "evalue": "Error(s) in loading state_dict for QuantConv2d:\n\tMissing key(s) in state_dict: \"weight_quant.tensor_quant.scaling_impl.value\". ", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mparam_from_max_quant_conv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfloat_conv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32mC:\\ProgramData\\Miniconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36mload_state_dict\u001b[1;34m(self, state_dict, strict)\u001b[0m\n\u001b[0;32m 1405\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0merror_msgs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1406\u001b[0m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001b[1;32m-> 1407\u001b[1;33m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001b[0m\u001b[0;32m 1408\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0m_IncompatibleKeys\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmissing_keys\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munexpected_keys\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1409\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantConv2d:\n\tMissing key(s) in state_dict: \"weight_quant.tensor_quant.scaling_impl.value\". " + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_625769/50754285.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mparam_from_max_quant_conv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfloat_conv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mload_state_dict\u001b[0;34m(self, state_dict, strict)\u001b[0m\n\u001b[1;32m 1670\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merror_msgs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1671\u001b[0m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001b[0;32m-> 1672\u001b[0;31m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001b[0m\u001b[1;32m 1673\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_IncompatibleKeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmissing_keys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munexpected_keys\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1674\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantConv2d:\n\tMissing key(s) in state_dict: \"weight_quant.tensor_quant.scaling_impl.value\". " ] } ], @@ -916,39 +944,39 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1897, -0.1897, 0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [-0.1897, 0.1897, -0.1897]],\n", + "QuantTensor(int_value=tensor([[[[-1., -1., 1.],\n", + " [-1., 1., 1.],\n", + " [-1., -1., 1.]],\n", "\n", - " [[-0.1897, 0.1897, 0.1897],\n", - " [ 0.1897, -0.1897, -0.1897],\n", - " [ 0.1897, -0.1897, 0.1897]],\n", + " [[-1., 1., -1.],\n", + " [ 1., 1., -1.],\n", + " [-1., -1., -1.]],\n", "\n", - " [[-0.1897, 0.1897, -0.1897],\n", - " [-0.1897, 0.1897, 0.1897],\n", - " [-0.1897, 0.1897, 0.1897]]],\n", + " [[ 1., 1., -1.],\n", + " [-1., 1., 1.],\n", + " [ 1., 1., -1.]]],\n", "\n", "\n", - " [[[ 0.1897, 0.1897, 0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, -0.1897]],\n", + " [[[-1., 1., -1.],\n", + " [ 1., 1., -1.],\n", + " [ 1., -1., -1.]],\n", "\n", - " [[ 0.1897, -0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, 0.1897]],\n", + " [[-1., -1., 1.],\n", + " [ 1., 1., -1.],\n", + " [ 1., 1., 1.]],\n", "\n", - " [[-0.1897, 0.1897, -0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, 0.1897]]]], grad_fn=), scale=tensor(0.1897, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[-1., -1., -1.],\n", + " [-1., 1., -1.],\n", + " [ 1., -1., -1.]]]], grad_fn=), scale=tensor(0.1880, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 25, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -979,7 +1007,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -1013,18 +1041,7 @@ "cell_type": "code", "execution_count": 26, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "quant_conv1 = QuantConv2d(3, 2, (3, 3), weight_quant=MySignedBinaryWeightQuantizer)\n", "quant_conv2 = QuantConv2d(3, 2, (3, 3), weight_quant=quant_conv1.weight_quant)\n", @@ -1036,19 +1053,7 @@ "cell_type": "code", "execution_count": 27, "metadata": {}, - "outputs": [ - { - "ename": "AssertionError", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_58415/1066539094.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mquant_conv1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_weight_scale\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mquant_conv2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_weight_scale\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m: " - ] - } - ], + "outputs": [], "source": [ "assert (quant_conv1.quant_weight_scale() == quant_conv2.quant_weight_scale()).item()" ] @@ -1065,18 +1070,7 @@ "cell_type": "code", "execution_count": 28, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "class SharedParamFromMeanWeightQuantizer(MySignedBinaryWeightQuantizer):\n", " \n", @@ -1097,7 +1091,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -1140,7 +1134,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -1159,7 +1153,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -1260,42 +1254,42 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1842, 0.1842, -0.1842],\n", - " [-0.1842, -0.1842, 0.1842],\n", - " [-0.1842, -0.1842, 0.1842]],\n", + "QuantTensor(int_value=tensor([[[[ 1., -1., 1.],\n", + " [ 1., 1., -1.],\n", + " [ 1., 1., -1.]],\n", "\n", - " [[-0.1842, -0.1842, 0.1842],\n", - " [ 0.1842, -0.1842, 0.1842],\n", - " [ 0.1842, 0.1842, -0.1842]],\n", + " [[-1., 1., -1.],\n", + " [-1., -1., -1.],\n", + " [-1., -1., -1.]],\n", "\n", - " [[-0.1842, -0.1842, 0.1842],\n", - " [ 0.1842, 0.1842, 0.1842],\n", - " [-0.1842, 0.1842, -0.1842]]],\n", + " [[-1., -1., -1.],\n", + " [ 1., 1., -1.],\n", + " [-1., 1., 1.]]],\n", "\n", "\n", - " [[[ 0.1838, 0.1838, 0.1838],\n", - " [-0.1838, -0.1838, -0.1838],\n", - " [ 0.1838, 0.1838, -0.1838]],\n", + " [[[-1., -1., -1.],\n", + " [ 1., -1., 1.],\n", + " [-1., 1., 1.]],\n", "\n", - " [[ 0.1838, -0.1838, 0.1838],\n", - " [ 0.1838, 0.1838, 0.1838],\n", - " [-0.1838, 0.1838, -0.1838]],\n", + " [[-1., -1., -1.],\n", + " [-1., 1., 1.],\n", + " [ 1., 1., 1.]],\n", "\n", - " [[-0.1838, 0.1838, -0.1838],\n", - " [ 0.1838, -0.1838, -0.1838],\n", - " [ 0.1838, -0.1838, 0.1838]]]], grad_fn=), scale=tensor([[[[0.1842]]],\n", + " [[ 1., -1., -1.],\n", + " [ 1., 1., -1.],\n", + " [-1., -1., 1.]]]], grad_fn=), scale=tensor([[[[0.1912]]],\n", "\n", "\n", - " [[[0.1838]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[[0.1853]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 35, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -1318,42 +1312,42 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1875, -0.1875, 0.1875],\n", - " [-0.1875, 0.1875, -0.1875],\n", - " [-0.1875, 0.1875, -0.1875]],\n", + "QuantTensor(int_value=tensor([[[[-1., -1., 1.],\n", + " [-1., 1., 1.],\n", + " [-1., -1., 1.]],\n", "\n", - " [[-0.1875, 0.1875, 0.1875],\n", - " [ 0.1875, -0.1875, -0.1875],\n", - " [ 0.1875, -0.1875, 0.1875]],\n", + " [[-1., 1., -1.],\n", + " [ 1., 1., -1.],\n", + " [-1., -1., -1.]],\n", "\n", - " [[-0.1875, 0.1875, -0.1875],\n", - " [-0.1875, 0.1875, 0.1875],\n", - " [-0.1875, 0.1875, 0.1875]]],\n", + " [[ 1., 1., -1.],\n", + " [-1., 1., 1.],\n", + " [ 1., 1., -1.]]],\n", "\n", "\n", - " [[[ 0.1897, 0.1897, 0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, -0.1897]],\n", + " [[[-1., 1., -1.],\n", + " [ 1., 1., -1.],\n", + " [ 1., -1., -1.]],\n", "\n", - " [[ 0.1897, -0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, 0.1897]],\n", + " [[-1., -1., 1.],\n", + " [ 1., 1., -1.],\n", + " [ 1., 1., 1.]],\n", "\n", - " [[-0.1897, 0.1897, -0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, 0.1897]]]], grad_fn=), scale=tensor([[[[0.1875]]],\n", + " [[-1., -1., -1.],\n", + " [-1., 1., -1.],\n", + " [ 1., -1., -1.]]]], grad_fn=), scale=tensor([[[[0.1880]]],\n", "\n", "\n", - " [[[0.1897]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[[0.1873]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 36, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -1374,19 +1368,19 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[-0.0100, -0.0100, 0.0100, -0.0100],\n", - " [-0.0100, -0.0100, -0.0100, 0.0100],\n", - " [-0.0100, 0.0100, 0.0100, 0.0100],\n", - " [-0.0100, 0.0100, 0.0100, 0.0100]], grad_fn=)" + "tensor([[ 0.0100, 0.0100, -0.0100, 0.0100],\n", + " [-0.0100, 0.0100, -0.0100, -0.0100],\n", + " [-0.0100, 0.0100, 0.0100, -0.0100],\n", + " [-0.0100, -0.0100, 0.0100, -0.0100]], grad_fn=)" ] }, - "execution_count": 37, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } @@ -1421,21 +1415,21 @@ "evalue": "'AdvancedActQuantizer' can not resolve attribute 'per_channel_broadcastable_shape'", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mDependencyError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m quant_identity = QuantIdentity(\n\u001b[1;32m----> 4\u001b[1;33m act_quant=AdvancedActQuantizer, is_clamped=True, scaling_per_output_channel=True)\n\u001b[0m", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_activation.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m 134\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 135\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 136\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 137\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_layer.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m 77\u001b[0m \u001b[0mpassthrough_act\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 79\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 80\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 81\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\act.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, act_quant, **kwargs)\u001b[0m\n\u001b[0;32m 157\u001b[0m \u001b[0mproxy_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'act_'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 158\u001b[0m \u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m''\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 159\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 160\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 161\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\base.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[0;32m 98\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mfilter_kwargs\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 100\u001b[1;33m \u001b[0mquant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mproxy_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 101\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mproxy_protocol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[0;32m 108\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 109\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 110\u001b[1;33m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mActQuantProxyFromInjector\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 111\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_passthrough_act\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_passthrough_act\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector, export_mode, export_handler)\u001b[0m\n\u001b[0;32m 74\u001b[0m \u001b[1;31m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 75\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 76\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_tracked_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 77\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_handler\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_mode\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36madd_tracked_module\u001b[1;34m(self, module)\u001b[0m\n\u001b[0;32m 130\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 131\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mupdate_tracked_modules\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 132\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 133\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 134\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Trying to add None as a parent module.\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36minit_tensor_quant\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 120\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 121\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 122\u001b[1;33m \u001b[0mtensor_quant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 123\u001b[0m \u001b[0mact_impl\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 124\u001b[0m \u001b[0mis_act_enabled\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_act_enabled\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtensor_quant\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - " \u001b[1;31m[... skipping hidden 1 frame]\u001b[0m\n", - "\u001b[1;32mC:\\ProgramData\\Miniconda3\\envs\\pytorch\\lib\\site-packages\\_dependencies\\this.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, __self__)\u001b[0m\n\u001b[0;32m 49\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mkind\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m\".\"\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 50\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 51\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msymbol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 52\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mDependencyError\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 53\u001b[0m message = (\n", - " \u001b[1;31m[... skipping hidden 1 frame]\u001b[0m\n", - "\u001b[1;31mDependencyError\u001b[0m: 'AdvancedActQuantizer' can not resolve attribute 'per_channel_broadcastable_shape'" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mDependencyError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_625769/833415959.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m quant_identity = QuantIdentity(\n\u001b[0;32m----> 4\u001b[0;31m act_quant=AdvancedActQuantizer, is_clamped=True, scaling_per_output_channel=True)\n\u001b[0m", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_activation.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mact_quant\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_quant_tensor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 102\u001b[0;31m **kwargs)\n\u001b[0m", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0mQuantLayerMixin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mQuantInputMixin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_quant\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 40\u001b[0;31m \u001b[0mQuantNonLinearActMixin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mact_impl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpassthrough_act\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/act.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, act_impl, passthrough_act, act_quant, act_proxy_prefix, act_kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0mnone_quant_injector\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mNoneActQuant\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mprefixed_kwargs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 126\u001b[0;31m **kwargs)\n\u001b[0m\u001b[1;32m 127\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/base.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquant\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mfilter_kwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs_prefix\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m \u001b[0mquant\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproxy_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 71\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproxy_protocol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/runtime_quant.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_layer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 89\u001b[0;31m \u001b[0mQuantProxyFromInjector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_layer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 90\u001b[0m \u001b[0mActQuantProxyProtocol\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_passthrough_act\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_is_passthrough_act\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[0;31m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 82\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_tracked_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 83\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisable_quant\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py\u001b[0m in \u001b[0;36madd_tracked_module\u001b[0;34m(self, module)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_tracked_modules\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 120\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_tensor_quant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 121\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Trying to add None as a parent module.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/runtime_quant.py\u001b[0m in \u001b[0;36minit_tensor_quant\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minit_tensor_quant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 102\u001b[0;31m \u001b[0mtensor_quant\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 103\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m'act_impl'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0mact_impl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/_dependencies/this.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, __self__)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mkind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\".\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msymbol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mDependencyError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m message = (\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/inject/__init__.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(cls, attrname)\u001b[0m\n\u001b[1;32m 127\u001b[0m message = \"{!r} can not resolve attribute {!r}\".format(\n\u001b[1;32m 128\u001b[0m cls.__name__, current_attr)\n\u001b[0;32m--> 129\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mDependencyError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 130\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0mmarker\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattribute\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhave_defaults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDependencyError\u001b[0m: 'AdvancedActQuantizer' can not resolve attribute 'per_channel_broadcastable_shape'" ] } ], @@ -1455,22 +1449,22 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[-0.0100, 0.0100, -0.0100, -0.0100],\n", - " [-0.0100, 0.0100, -0.0100, -0.0100],\n", - " [ 0.0100, -0.0100, 0.0100, -0.0100],\n", - " [ 0.0100, -0.0100, -0.0100, -0.0100]], grad_fn=), scale=tensor([[0.0100],\n", + "QuantTensor(int_value=tensor([[ 1., 1., -1., 1.],\n", + " [-1., 1., -1., 1.],\n", + " [-1., 1., -1., 1.],\n", + " [ 1., 1., 1., -1.]], grad_fn=), scale=tensor([[0.0100],\n", " [0.0100],\n", " [0.0100],\n", " [0.0100]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 39, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } @@ -1506,7 +1500,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.16" } }, "nbformat": 4, diff --git a/notebooks/Brevitas_TVMCon2021.ipynb b/notebooks/Brevitas_TVMCon2021.ipynb index efd9421f0..86607c530 100644 --- a/notebooks/Brevitas_TVMCon2021.ipynb +++ b/notebooks/Brevitas_TVMCon2021.ipynb @@ -31,6 +31,14 @@ "execution_count": 1, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, { "data": { "text/markdown": [ @@ -45,8 +53,10 @@ " input_quant: Optional[ActQuantType] = None,\n", " output_quant: Optional[ActQuantType] = None,\n", " return_quant_tensor: bool = False,\n", + " device: Optional[torch.device] = None,\n", + " dtype: Optional[torch.dtype] = None,\n", " **kwargs) -> None:\n", - " Linear.__init__(self, in_features, out_features, bias)\n", + " Linear.__init__(self, in_features, out_features, bias, device=device, dtype=dtype)\n", " QuantWBIOL.__init__(\n", " self,\n", " weight_quant=weight_quant,\n", @@ -100,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -114,11 +124,17 @@ " [-0.2723, 0.1896],\n", " [-0.0140, 0.5607]], requires_grad=True) \n", "\n", - "Quantized weight QuantTensor:\n", - " QuantTensor(value=tensor([[-0.0046, 0.3803],\n", + "Quantized integer weight QuantTensor:\n", + " QuantTensor(int_value=tensor([[ -1., 83.],\n", + " [-127., -114.],\n", + " [ -59., 41.],\n", + " [ -3., 122.]], grad_fn=), scale=tensor(0.0046, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", + "\n", + "Fake quantized weight QuantTensor:\n", + " tensor([[-0.0046, 0.3803],\n", " [-0.5820, -0.5224],\n", " [-0.2704, 0.1879],\n", - " [-0.0137, 0.5591]], grad_fn=), scale=tensor(0.0046, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", + " [-0.0137, 0.5591]], grad_fn=) \n", "\n" ] } @@ -131,7 +147,8 @@ "quant_linear = QuantLinear(2, 4, bias=True)\n", "\n", "print(f\"Original float weight tensor:\\n {quant_linear.weight} \\n\")\n", - "print(f\"Quantized weight QuantTensor:\\n {quant_linear.quant_weight()} \\n\")" + "print(f\"Quantized integer weight QuantTensor:\\n {quant_linear.quant_weight()} \\n\")\n", + "print(f\"Fake quantized weight QuantTensor:\\n {quant_linear.quant_weight().value} \\n\")" ] }, { @@ -150,7 +167,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -158,10 +175,10 @@ "output_type": "stream", "text": [ "Quantized Weight integer tensor:\n", - " tensor([[ -1, 83],\n", - " [-127, -114],\n", - " [ -59, 41],\n", - " [ -3, 122]], dtype=torch.int32)\n" + " tensor([[ -1., 83.],\n", + " [-127., -114.],\n", + " [ -59., 41.],\n", + " [ -3., 122.]], grad_fn=)\n" ] } ], @@ -180,7 +197,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -194,7 +211,15 @@ "Float output:\n", " tensor([[-0.9036, -0.4586, 0.3096, -0.6472],\n", " [ 1.2058, 0.6525, -0.3723, 0.8677],\n", - " [ 1.3873, 0.2801, -0.9009, 0.9507]], grad_fn=)\n" + " [ 1.3873, 0.2801, -0.9009, 0.9507]], grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/torch/_tensor.py:1255: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525553989/work/c10/core/TensorImpl.h:1758.)\n", + " return super(Tensor, self).rename(names)\n" ] } ], @@ -227,7 +252,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -235,10 +260,10 @@ "output_type": "stream", "text": [ "Weight QuantTensor:\n", - " QuantTensor(value=tensor([[-0.0078, 0.3828],\n", - " [-0.5781, -0.5234],\n", - " [-0.2734, 0.1875],\n", - " [-0.0156, 0.5625]], grad_fn=), scale=tensor(0.0078, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))\n", + " QuantTensor(int_value=tensor([[ -1., 49.],\n", + " [-74., -67.],\n", + " [-35., 24.],\n", + " [ -2., 72.]], grad_fn=), scale=tensor(0.0078, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))\n", "Weight fix point: 7.0\n" ] } @@ -266,7 +291,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -274,10 +299,10 @@ "output_type": "stream", "text": [ "Weight QuantTensor:\n", - " QuantTensor(value=tensor([[-0.1000, 0.1000],\n", - " [-0.1000, -0.1000],\n", - " [-0.1000, 0.1000],\n", - " [-0.1000, 0.1000]], grad_fn=), scale=tensor(0.1000), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))\n" + " QuantTensor(int_value=tensor([[-1., 1.],\n", + " [-1., -1.],\n", + " [-1., 1.],\n", + " [-1., 1.]], grad_fn=), scale=tensor(0.1000), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))\n" ] } ], @@ -311,7 +336,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -357,7 +382,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -372,7 +397,7 @@ "Quant output:\n", " tensor([[-0.9109, -0.4609, 0.3135, -0.6523],\n", " [ 1.2089, 0.6524, -0.3752, 0.8697],\n", - " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=)\n" + " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=)\n" ] } ], @@ -399,7 +424,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -407,9 +432,9 @@ "output_type": "stream", "text": [ "Quant output:\n", - " QuantTensor(value=tensor([[-0.9109, -0.4609, 0.3135, -0.6523],\n", - " [ 1.2089, 0.6524, -0.3752, 0.8697],\n", - " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" + " QuantTensor(int_value=tensor([[-10061., -5090., 3463., -7204.],\n", + " [ 13352., 7205., -4144., 9606.],\n", + " [ 15344., 3110., -9952., 10516.]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" ] } ], @@ -437,7 +462,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -450,14 +475,14 @@ " [-1.0845, -1.3986]]) \n", "\n", "Quant input:\n", - " QuantTensor(value=tensor([[ 1.5490, -0.2894],\n", - " [-2.1788, 0.5617],\n", - " [-1.0894, -1.3958]], grad_fn=), scale=tensor(0.0170, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", + " QuantTensor(int_value=tensor([[ 91., -17.],\n", + " [-128., 33.],\n", + " [ -64., -82.]], grad_fn=), scale=tensor(0.0170, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", "\n", "Quant output:\n", - " QuantTensor(value=tensor([[-0.9109, -0.4609, 0.3135, -0.6523],\n", - " [ 1.2089, 0.6524, -0.3752, 0.8697],\n", - " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" + " QuantTensor(int_value=tensor([[-10061., -5090., 3463., -7204.],\n", + " [ 13352., 7205., -4144., 9606.],\n", + " [ 15344., 3110., -9952., 10516.]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" ] } ], @@ -496,7 +521,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -509,9 +534,9 @@ " [-1.0845, -1.3986]]) \n", "\n", "Quant output:\n", - " QuantTensor(value=tensor([[1.5410, 0.0000],\n", - " [0.0000, 0.5681],\n", - " [0.0000, 0.0000]], grad_fn=), scale=tensor(0.0060, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))\n" + " QuantTensor(int_value=tensor([[255., 0.],\n", + " [ 0., 94.],\n", + " [ 0., 0.]], grad_fn=), scale=tensor(0.0060, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))\n" ] } ], @@ -540,7 +565,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -553,13 +578,13 @@ " [-1.0845, -1.3986]]) \n", "\n", "Quant output after QuantIdentity:\n", - " QuantTensor(value=tensor([[ 1.5490, -0.2894],\n", - " [-2.1788, 0.5617],\n", - " [-1.0894, -1.3958]], grad_fn=), scale=tensor(0.0170, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))\n", + " QuantTensor(int_value=tensor([[ 91., -17.],\n", + " [-128., 33.],\n", + " [ -64., -82.]], grad_fn=), scale=tensor(0.0170, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))\n", "Quant output after QuantReLU:\n", - " QuantTensor(value=tensor([[1.5490, 0.0000],\n", - " [0.0000, 0.5588],\n", - " [0.0000, 0.0000]], grad_fn=), scale=tensor(0.0061, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))\n" + " QuantTensor(int_value=tensor([[255., 0.],\n", + " [ 0., 92.],\n", + " [ 0., 0.]], grad_fn=), scale=tensor(0.0061, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))\n" ] } ], @@ -602,7 +627,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": { "tags": [ "raises-exception" @@ -614,15 +639,15 @@ "evalue": "Input scale required", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32mC:\\Users\\ALESSA~1\\AppData\\Local\\Temp/ipykernel_18920/2660651517.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[0mquant_linear\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mQuantLinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m4\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias_quant\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mInt16Bias\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 7\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 8\u001b[1;33m \u001b[0mquant_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_linear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfloat_input\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32m~\\miniconda3\\envs\\pt190\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1051\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1052\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\users\\alessandro\\documenti\\brevitas_tvmcon\\src\\brevitas\\nn\\quant_linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 96\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 97\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mUnion\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 98\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 99\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 100\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0minner_forward_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_bias\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\users\\alessandro\\documenti\\brevitas_tvmcon\\src\\brevitas\\nn\\quant_layer.py\u001b[0m in \u001b[0;36mforward_impl\u001b[1;34m(self, inp)\u001b[0m\n\u001b[0;32m 355\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 356\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 357\u001b[1;33m \u001b[0mquant_bias\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput_scale\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput_bit_width\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 358\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcache_inference_quant_bias\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 359\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_cached_bias\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_CachedIO\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_bias\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m~\\miniconda3\\envs\\pt190\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1051\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1052\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\users\\alessandro\\documenti\\brevitas_tvmcon\\src\\brevitas\\proxy\\parameter_quant.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[0;32m 194\u001b[0m \u001b[0mimpl\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[1;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 195\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrequires_input_scale\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0minput_scale\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 196\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Input scale required\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 197\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrequires_input_bit_width\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0minput_bit_width\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 198\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Input bit-width required\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mRuntimeError\u001b[0m: Input scale required" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_626130/2660651517.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mquant_linear\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mQuantLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias_quant\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mInt16Bias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mquant_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquant_linear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfloat_input\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1195\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 66\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 67\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minner_forward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_bias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py\u001b[0m in \u001b[0;36mforward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 337\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 339\u001b[0;31m \u001b[0mquant_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias_quant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_scale\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_bit_width\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 340\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_inference_quant_bias\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 341\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cached_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CachedIO\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_bias\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1195\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/parameter_quant.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0mimpl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_scale\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_scale\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 206\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input scale required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 207\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_bit_width\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_bit_width\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input bit-width required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Input scale required" ] } ], @@ -646,18 +671,18 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[-0.6541, 0.1263, 0.1680, -0.1231],\n", - " [ 1.4658, 1.2395, -0.5207, 1.3989],\n", - " [ 1.6461, 0.8687, -1.0466, 1.4813]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(18.), signed_t=tensor(True), training_t=tensor(True))" + "QuantTensor(int_value=tensor([[ -7224., 1395., 1856., -1360.],\n", + " [ 16189., 13690., -5751., 15450.],\n", + " [ 18181., 9595., -11559., 16360.]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(18.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 14, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -703,7 +728,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -711,17 +736,17 @@ "output_type": "stream", "text": [ "Eval mode add quant inputs:\n", - " QuantTensor(value=tensor([[ 1.5335, -0.2875],\n", - " [-2.0447, 0.5751],\n", - " [-1.0863, -1.4057]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) \n", - " QuantTensor(value=tensor([[ 0.3994, 0.8307],\n", - " [-0.7188, -0.3994],\n", - " [-0.5910, 0.1757]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) \n", + " QuantTensor(int_value=tensor([[ 96., -18.],\n", + " [-128., 36.],\n", + " [ -68., -88.]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) \n", + " QuantTensor(int_value=tensor([[ 25., 52.],\n", + " [-45., -25.],\n", + " [-37., 11.]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) \n", "\n", "Eval mode add quant output:\n", - " QuantTensor(value=tensor([[ 1.9329, 0.5431],\n", - " [-2.7636, 0.1757],\n", - " [-1.6773, -1.2300]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(9.), signed_t=tensor(True), training_t=tensor(False))\n" + " QuantTensor(int_value=tensor([[ 121., 34.],\n", + " [-173., 11.],\n", + " [-105., -77.]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(9.), signed_t=tensor(True), training_t=tensor(False))\n" ] } ], @@ -769,7 +794,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -777,32 +802,24 @@ "output_type": "stream", "text": [ "Quant input:\n", - " QuantTensor(value=tensor([[[-1.1218, -1.1580, -0.2533, -0.4343],\n", - " [ 0.8504, 0.6876, -0.3076, -2.1170]],\n", + " QuantTensor(int_value=tensor([[[ -62., -64., -14., -24.],\n", + " [ 47., 38., -17., -117.]],\n", "\n", - " [[ 0.4704, -0.1628, 1.4475, 0.2714],\n", - " [ 0.1628, 0.8685, -0.1448, -0.1086]],\n", + " [[ 26., -9., 80., 15.],\n", + " [ 9., 48., -8., -6.]],\n", "\n", - " [[ 0.9228, 1.2666, 2.0084, 0.0543],\n", - " [ 0.6152, -0.4162, -0.8323, -2.3160]]], grad_fn=), scale=tensor(0.0181, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", + " [[ 51., 70., 111., 3.],\n", + " [ 34., -23., -46., -128.]]], grad_fn=), scale=tensor(0.0181, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", "\n", "Quant output:\n", - " QuantTensor(value=tensor([[[-1.1218, -0.2533],\n", - " [ 0.8504, -0.3076]],\n", + " QuantTensor(int_value=tensor([[[-62., -14.],\n", + " [ 47., -17.]],\n", "\n", - " [[ 0.4704, 1.4475],\n", - " [ 0.8685, -0.1086]],\n", + " [[ 26., 80.],\n", + " [ 48., -6.]],\n", "\n", - " [[ 1.2666, 2.0084],\n", - " [ 0.6152, -0.8323]]], grad_fn=), scale=tensor(0.0181, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\Alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages\\torch\\nn\\functional.py:652: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ..\\c10/core/TensorImpl.h:1156.)\n", - " return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)\n" + " [[ 70., 111.],\n", + " [ 34., -46.]]], grad_fn=), scale=tensor(0.0181, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))\n" ] } ], @@ -830,7 +847,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -838,14 +855,14 @@ "output_type": "stream", "text": [ "Quant input:\n", - " QuantTensor(value=tensor([[[-1.1218, -1.1580, -0.2533, -0.4343],\n", - " [ 0.8504, 0.6876, -0.3076, -2.1170]],\n", + " QuantTensor(int_value=tensor([[[ -62., -64., -14., -24.],\n", + " [ 47., 38., -17., -117.]],\n", "\n", - " [[ 0.4704, -0.1628, 1.4475, 0.2714],\n", - " [ 0.1628, 0.8685, -0.1448, -0.1086]],\n", + " [[ 26., -9., 80., 15.],\n", + " [ 9., 48., -8., -6.]],\n", "\n", - " [[ 0.9228, 1.2666, 2.0084, 0.0543],\n", - " [ 0.6152, -0.4162, -0.8323, -2.3160]]], grad_fn=), scale=tensor(0.0181, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", + " [[ 51., 70., 111., 3.],\n", + " [ 34., -23., -46., -128.]]], grad_fn=), scale=tensor(0.0181, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", "\n", "Quant output:\n", " tensor([[[-0.8082, -0.8204, -0.2480, -0.4089],\n", @@ -855,7 +872,15 @@ " [ 0.1614, 0.7006, -0.1438, -0.1081]],\n", "\n", " [[ 0.7272, 0.8529, 0.9646, 0.0542],\n", - " [ 0.5478, -0.3937, -0.6817, -0.9807]]], grad_fn=)\n" + " [ 0.5478, -0.3937, -0.6817, -0.9807]]], grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/ipykernel_launcher.py:7: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525553989/work/torch/csrc/utils/python_arg_parser.cpp:350.)\n", + " import sys\n" ] } ], @@ -883,24 +908,29 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "Eval mode concat quant inputs:\n", - " QuantTensor(value=tensor([[ 1.5335, -0.2875],\n", - " [-2.0447, 0.5751],\n", - " [-1.0863, -1.4057]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) QuantTensor(value=tensor([[ 0.3994, 0.8307],\n", - " [-0.7188, -0.3994],\n", - " [-0.5910, 0.1757]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) \n", - "\n", - "Eval mode concat quant output:\n", - " QuantTensor(value=tensor([[ 1.5335, -0.2875, 0.3994, 0.8307],\n", - " [-2.0447, 0.5751, -0.7188, -0.3994],\n", - " [-1.0863, -1.4057, -0.5910, 0.1757]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False))\n" + "/home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/ipykernel_launcher.py:8: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525553989/work/torch/csrc/utils/python_arg_parser.cpp:350.)\n", + " \n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "Scaling factors are different", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_626130/3932472163.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;31m#Training mode, statistics are being collected, scaling factors are different but it doesn't raise an error\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mtrain_mode_cat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mquant_identity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfloat_inp1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_identity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfloat_inp2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;31m#Inference mode, the EMA buffer is being used, scaling factors are the same\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/quant_tensor/__init__.py\u001b[0m in \u001b[0;36m__torch_function__\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_unpack_quant_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 107\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mQUANT_TENSOR_FN_HANDLER\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/quant_tensor/torch_handler.py\u001b[0m in \u001b[0;36mcat_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcat_handler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mbrevitas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_tensor\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 50\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 51\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/quant_tensor/__init__.py\u001b[0m in \u001b[0;36mcat\u001b[0;34m(tensors, dim, out)\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mqt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mfirst_qt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_scaling_factors_same\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mqt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 306\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Scaling factors are different\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 307\u001b[0m \u001b[0mfirst_qt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_zero_points_same\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mqt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 308\u001b[0m \u001b[0mfirst_qt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_bit_width_same\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mqt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Scaling factors are different" ] } ], @@ -946,7 +976,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -954,10 +984,10 @@ "output_type": "stream", "text": [ "Weight QuantTensor:\n", - " QuantTensor(value=tensor([[-0.0000, 0.3880],\n", - " [-0.5820, -0.5044],\n", - " [-0.2716, 0.1940],\n", - " [-0.0000, 0.5432]], grad_fn=), scale=tensor(0.0388, grad_fn=), zero_point=tensor(0.), bit_width=tensor(5.), signed_t=tensor(True), training_t=tensor(True))\n" + " QuantTensor(int_value=tensor([[ -0., 10.],\n", + " [-15., -13.],\n", + " [ -7., 5.],\n", + " [ -0., 14.]], grad_fn=), scale=tensor(0.0388, grad_fn=), zero_point=tensor(0.), bit_width=tensor(5.), signed_t=tensor(True), training_t=tensor(True))\n" ] } ], @@ -980,7 +1010,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -988,10 +1018,10 @@ "output_type": "stream", "text": [ "Weight QuantTensor:\n", - " QuantTensor(value=tensor([[-0.0000, 0.3793],\n", - " [-0.5820, -0.5044],\n", - " [-0.2723, 0.1816],\n", - " [-0.0000, 0.5607]], grad_fn=), scale=tensor([[0.0253],\n", + " QuantTensor(int_value=tensor([[ -0., 15.],\n", + " [-15., -13.],\n", + " [-15., 10.],\n", + " [ -0., 15.]], grad_fn=), scale=tensor([[0.0253],\n", " [0.0388],\n", " [0.0182],\n", " [0.0374]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(5.), signed_t=tensor(True), training_t=tensor(True))\n" @@ -1017,7 +1047,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -1025,9 +1055,9 @@ "output_type": "stream", "text": [ "QuantTensor:\n", - " QuantTensor(value=tensor([[ 1.6341, -0.5447],\n", - " [-2.1788, 0.5447],\n", - " [-1.0894, -1.6341]], grad_fn=), scale=tensor(0.5447, grad_fn=), zero_point=tensor(0.), bit_width=tensor(3.), signed_t=tensor(True), training_t=tensor(True))\n" + " QuantTensor(int_value=tensor([[ 3., -1.],\n", + " [-4., 1.],\n", + " [-2., -3.]], grad_fn=), scale=tensor(0.5447, grad_fn=), zero_point=tensor(0.), bit_width=tensor(3.), signed_t=tensor(True), training_t=tensor(True))\n" ] } ], @@ -1050,18 +1080,18 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[1.5294, 0.0000],\n", - " [0.0000, 0.5647],\n", - " [0.0000, 0.0000]], grad_fn=), scale=tensor(0.0235, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))" + "QuantTensor(int_value=tensor([[65., 0.],\n", + " [ 0., 24.],\n", + " [ 0., 0.]], grad_fn=), scale=tensor(0.0235, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))" ] }, - "execution_count": 22, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -1087,7 +1117,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -1099,8 +1129,8 @@ " [-1.3986, 0.4033, 0.8380, -0.7193, -0.4033]]]) \n", "\n", "Per-channel quant output:\n", - " QuantTensor(value=tensor([[[ 0.8616, -0.7012, 0.4503],\n", - " [-1.1285, -0.4937, -0.1901]]], grad_fn=), scale=tensor([[[0.0021],\n", + " QuantTensor(int_value=tensor([[[ 419., -341., 219.],\n", + " [-855., -374., -144.]]], grad_fn=), scale=tensor([[[0.0021],\n", " [0.0013]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" ] } @@ -1145,7 +1175,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -1157,8 +1187,8 @@ " [-1.3986, 0.4033, 0.8380, -0.7193, -0.4033]]]) \n", "\n", "Per-channel quant output:\n", - " QuantTensor(value=tensor([[[ 0.8616, -0.7012, 0.4503],\n", - " [-1.1285, -0.4937, -0.1901]]], grad_fn=), scale=tensor([[[0.0021],\n", + " QuantTensor(int_value=tensor([[[ 419., -341., 219.],\n", + " [-855., -374., -144.]]], grad_fn=), scale=tensor([[[0.0021],\n", " [0.0013]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" ] } @@ -1219,7 +1249,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -1253,7 +1283,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -1293,7 +1323,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -1301,10 +1331,10 @@ "output_type": "stream", "text": [ "Weight QuantTensor:\n", - " QuantTensor(value=tensor([[-0.0060, 0.3793],\n", - " [-0.5820, -0.5224],\n", - " [-0.2723, 0.1887],\n", - " [-0.0132, 0.5607]], grad_fn=), scale=tensor([[0.0030],\n", + " QuantTensor(int_value=tensor([[ -2., 127.],\n", + " [-127., -114.],\n", + " [-127., 88.],\n", + " [ -3., 127.]], grad_fn=), scale=tensor([[0.0030],\n", " [0.0046],\n", " [0.0021],\n", " [0.0044]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(8., grad_fn=), signed_t=tensor(True), training_t=tensor(True))\n" @@ -1337,19 +1367,19 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[-0.9109, -0.4588, 0.3119, -0.6530],\n", - " [ 1.2089, 0.6493, -0.3731, 0.8706],\n", - " [ 1.3893, 0.2823, -0.8979, 0.9543]], grad_fn=), scale=tensor([[9.0542e-05, 3.9068e-05, 5.6866e-05, 6.4251e-05]],\n", - " grad_fn=), zero_point=tensor(0.), bit_width=tensor(17., grad_fn=), signed_t=tensor(True), training_t=tensor(True))" + "QuantTensor(int_value=tensor([[-10061., -11744., 5485., -10163.],\n", + " [ 13352., 16619., -6561., 13550.],\n", + " [ 15344., 7226., -15790., 14852.]], grad_fn=), scale=tensor([[9.0542e-05, 3.9068e-05, 5.6866e-05, 6.4251e-05]],\n", + " grad_fn=), zero_point=tensor(0.), bit_width=tensor(17., grad_fn=), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 28, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -1394,7 +1424,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 31, "metadata": { "tags": [ "raises-exception" @@ -1406,11 +1436,11 @@ "evalue": "Error(s) in loading state_dict for QuantLinear:\n\tMissing key(s) in state_dict: \"input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value\", \"input_quant.fused_activation_quant_proxy.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\", \"weight_quant.tensor_quant.scaling_impl.value\", \"weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\". ", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32mC:\\Users\\ALESSA~1\\AppData\\Local\\Temp/ipykernel_18920/1653109852.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 10\u001b[0m return_quant_tensor=True, bias=False)\n\u001b[0;32m 11\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 12\u001b[1;33m \u001b[0mquant_linear\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfloat_linear\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32m~\\miniconda3\\envs\\pt190\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36mload_state_dict\u001b[1;34m(self, state_dict, strict)\u001b[0m\n\u001b[0;32m 1405\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0merror_msgs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1406\u001b[0m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001b[1;32m-> 1407\u001b[1;33m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001b[0m\u001b[0;32m 1408\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0m_IncompatibleKeys\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmissing_keys\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munexpected_keys\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1409\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantLinear:\n\tMissing key(s) in state_dict: \"input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value\", \"input_quant.fused_activation_quant_proxy.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\", \"weight_quant.tensor_quant.scaling_impl.value\", \"weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\". " + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_626130/1653109852.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m return_quant_tensor=True, bias=False)\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mquant_linear\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfloat_linear\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/miniconda3/envs/torch_1.13/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mload_state_dict\u001b[0;34m(self, state_dict, strict)\u001b[0m\n\u001b[1;32m 1670\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merror_msgs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1671\u001b[0m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001b[0;32m-> 1672\u001b[0;31m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001b[0m\u001b[1;32m 1673\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_IncompatibleKeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmissing_keys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munexpected_keys\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1674\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantLinear:\n\tMissing key(s) in state_dict: \"input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value\", \"input_quant.fused_activation_quant_proxy.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\", \"weight_quant.tensor_quant.scaling_impl.value\", \"weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\". " ] } ], @@ -1440,7 +1470,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -1449,7 +1479,7 @@ "" ] }, - "execution_count": 30, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -1481,7 +1511,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 33, "metadata": {}, "outputs": [ { @@ -1548,7 +1578,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 34, "metadata": {}, "outputs": [ { @@ -1575,10 +1605,12 @@ " (stats): _Stats(\n", " (stats_impl): AbsPercentile()\n", " )\n", - " (restrict_clamp_scaling): _RestrictClampValue(\n", - " (clamp_min_ste): Identity()\n", + " (restrict_scaling): _RestrictValue(\n", " (restrict_value_impl): FloatRestrictValue()\n", " )\n", + " (clamp_scaling): _ClampValue(\n", + " (clamp_min_ste): ScalarClampMinSte()\n", + " )\n", " (restrict_inplace_preprocess): Identity()\n", " (restrict_preprocess): Identity()\n", " )\n", @@ -1595,7 +1627,7 @@ ")" ] }, - "execution_count": 32, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -1617,7 +1649,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -1669,7 +1701,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -1729,7 +1761,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ @@ -1785,7 +1817,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 38, "metadata": {}, "outputs": [ { @@ -1794,7 +1826,7 @@ "True" ] }, - "execution_count": 36, + "execution_count": 38, "metadata": {}, "output_type": "execute_result" } @@ -1845,20 +1877,19 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: netron in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (5.3.9)\n", - "Requirement already satisfied: onnx in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (1.10.2)\n", - "Requirement already satisfied: onnxoptimizer in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (0.2.6)\n", - "Requirement already satisfied: numpy>=1.16.6 in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (from onnx) (1.21.2)\n", - "Requirement already satisfied: typing-extensions>=3.6.2.1 in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (from onnx) (3.10.0.2)\n", - "Requirement already satisfied: protobuf in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (from onnx) (3.19.1)\n", - "Requirement already satisfied: six in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (from onnx) (1.16.0)\n" + "Requirement already satisfied: netron in /home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages (7.1.8)\n", + "Requirement already satisfied: onnx in /home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages (1.11.0)\n", + "Requirement already satisfied: onnxoptimizer in /home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages (0.3.7)\n", + "Requirement already satisfied: protobuf>=3.12.2 in /home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages (from onnx) (3.20.1)\n", + "Requirement already satisfied: typing-extensions>=3.6.2.1 in /home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages (from onnx) (4.7.0)\n", + "Requirement already satisfied: numpy>=1.16.6 in /home/giuseppe/miniconda3/envs/torch_1.13/lib/python3.7/site-packages (from onnx) (1.21.6)\n" ] } ], @@ -1868,7 +1899,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ @@ -1894,9 +1925,38 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 41, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/Documents/git/brevitas/src/brevitas/export/onnx/standard/manager.py:26: UserWarning: ONNX opset version set to 13, override with opset_version=\n", + " warnings.warn(f\"ONNX opset version set to {DEFAULT_OPSET}, override with {ka}=\")\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'Tensor' object has no attribute 'value'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_626130/3904900504.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0moutput_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'qop_onnx_conv_4b8b.onnx'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mexport_onnx_qop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_conv_4b8b\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_t\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfloat_inp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexport_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/export/__init__.py\u001b[0m in \u001b[0;36mexport_onnx_qop\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mwraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mStdQOpONNXManager\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mexport_onnx_qop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mStdQOpONNXManager\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/export/onnx/manager.py\u001b[0m in \u001b[0;36mexport\u001b[0;34m(cls, module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m **onnx_export_kwargs):\n\u001b[1;32m 162\u001b[0m return cls.export_onnx(\n\u001b[0;32m--> 163\u001b[0;31m module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)\n\u001b[0m", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/export/onnx/standard/manager.py\u001b[0m in \u001b[0;36mexport_onnx\u001b[0;34m(cls, module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msolve_onnx_opset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0monnx_export_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m output = super().export_onnx(\n\u001b[0;32m---> 47\u001b[0;31m module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)\n\u001b[0m\u001b[1;32m 48\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/export/onnx/manager.py\u001b[0m in \u001b[0;36mexport_onnx\u001b[0;34m(cls, module, args, export_path, input_shape, input_t, disable_warnings, **onnx_export_kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cache_inp_out\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mmodel_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 113\u001b[0;31m \u001b[0mcls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cache_inp_out\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 114\u001b[0m \u001b[0;31m# Dequantize QuantTensor, if any and enabled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/export/manager.py\u001b[0m in \u001b[0;36m_cache_inp_out\u001b[0;34m(cls, module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 263\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmgr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cache_patches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 264\u001b[0m \u001b[0mstack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menter_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmgr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 265\u001b[0;31m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 266\u001b[0m \u001b[0;31m# Restore previous caching properties\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 267\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0m_restore_quant_metadata_caching_mode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minner_forward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_bias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py\u001b[0m in \u001b[0;36mforward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[0moutput_signed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 306\u001b[0;31m \u001b[0minp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munpack_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 307\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 308\u001b[0m \u001b[0;31m# shortcut execution through the export impl during export\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/base.py\u001b[0m in \u001b[0;36munpack_input\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;31m# inp = QuantTensor(inp, scale=torch.tensor(1.0, device=inp.device, dtype=inp.dtype), training=self.training)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_inference_quant_inp\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 172\u001b[0;31m \u001b[0mcached_inp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CachedIO\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_quant_io_metadata_only\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 173\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cached_inp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcached_inp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[0;31m# print(inp)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/base.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, quant_tensor, metadata_only)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_tensor\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquant_tensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_tensor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquant_tensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'Tensor' object has no attribute 'value'" + ] + } + ], "source": [ "torch.manual_seed(0)\n", "\n", @@ -2280,7 +2340,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.15 | packaged by conda-forge | (default, Nov 22 2022, 08:42:03) [MSC v.1929 64 bit (AMD64)]" + "version": "3.7.16" }, "vscode": { "interpreter": { diff --git a/src/brevitas/core/quant/binary.py b/src/brevitas/core/quant/binary.py index 392cdeb62..d0761b9b5 100644 --- a/src/brevitas/core/quant/binary.py +++ b/src/brevitas/core/quant/binary.py @@ -58,7 +58,7 @@ def __init__(self, scaling_impl: Module, quant_delay_steps: int = 0): @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: scale = self.scaling_impl(x) - y = binary_sign_ste(x) * scale + y = binary_sign_ste(x) #* scale y = self.delay_wrapper(x, y) return y, scale, self.zero_point(), self.bit_width() @@ -119,6 +119,6 @@ def __init__( def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: scale = self.scaling_impl(x) y = self.tensor_clamp_impl(x, -scale, scale) - y = binary_sign_ste(y) * scale + y = binary_sign_ste(y) #* scale y = self.delay_wrapper(x, y) return y, scale, self.zero_point(), self.bit_width() diff --git a/src/brevitas/core/quant/int_base.py b/src/brevitas/core/quant/int_base.py index b79f13316..d83a5dbfa 100644 --- a/src/brevitas/core/quant/int_base.py +++ b/src/brevitas/core/quant/int_base.py @@ -81,10 +81,11 @@ def max_int(self, bit_width): @brevitas.jit.script_method def forward(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor: - y_int = self.to_int(scale, zero_point, bit_width, x) - y = y_int - zero_point - y = y * scale + y = self.to_int(scale, zero_point, bit_width, x) + # y = y_int - zero_point + # y = y * scale y = self.delay_wrapper(x, y) + # print(f"Only once {y}") return y diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index 194631953..b2bf0f154 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -477,7 +477,7 @@ def evaluate_loss(self, x, candidate): self.set_local_loss_mode(True) quant_value = self.proxy_forward(x) if isinstance(quant_value, tuple): - quant_value = quant_value[0] + quant_value = quant_value.value loss = self.mse_loss_fn(x, quant_value) self.set_local_loss_mode(False) return loss diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index 9eed8b38e..82a891d33 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -142,7 +142,7 @@ def disable_act_quant_hook(self, module, inp, output): if isinstance(module.tracked_module_list[0], QuantHardTanh): inp = F.hardtanh( inp, min_val=module.quant_injector.min_val, max_val=module.quant_injector.max_val) - return QuantTensor(value=inp, training=module.training) + return inp def disable_act_quantization(self, model, is_training): # If self.call_act_quantizer_impl is set to True, the quantization will be performed but the output diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 2d4fa97ad..4fb92e3fc 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -167,14 +167,18 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]): cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) self._cached_inp = cached_inp else: - inp = QuantTensor(inp, training=self.training) + # inp = QuantTensor(inp, scale=torch.tensor(1.0, device=inp.device, dtype=inp.dtype), training=self.training) if not self.training and self.cache_inference_quant_inp: cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) self._cached_inp = cached_inp + # print(inp) # Remove any naming metadata to avoid dowmstream errors # Avoid inplace operations on the input in case of forward hooks if not torch._C._get_tracing_state(): - inp = inp.set(value=inp.value.rename(None)) + if isinstance(inp, QuantTensor): + inp = inp.set(qt_value=inp.qt_value.rename(None)) + else: + inp = inp.rename(None) return inp def pack_output(self, quant_output: QuantTensor): @@ -184,7 +188,10 @@ def pack_output(self, quant_output: QuantTensor): if self.return_quant_tensor: return quant_output else: - return quant_output.value + if isinstance(quant_output, QuantTensor): + return quant_output.value + else: + return quant_output class QuantRecurrentLayerMixin(ExportMixin): @@ -246,9 +253,10 @@ def gate_params_fwd(gate, quant_input): acc_bit_width = None quant_weight_ih = gate.input_weight() quant_weight_hh = gate.hidden_weight() - if quant_input.bit_width is not None: + if isinstance(quant_input, QuantTensor): acc_bit_width = None # TODO - if quant_input.scale is not None and quant_weight_ih.scale is not None: + if getattr(quant_input, 'scale', None) is not None and getattr( + quant_weight_ih, 'scale', None) is not None: acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1) acc_scale = quant_weight_ih.scale.view(acc_scale_shape) acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape) @@ -267,8 +275,8 @@ def maybe_quantize_input(self, inp): quant_input = inp if not self.quantize_output_only: quant_input = self.io_quant(quant_input) - elif not isinstance(inp, QuantTensor): - quant_input = QuantTensor(quant_input) + # elif not isinstance(inp, QuantTensor): + # quant_input = QuantTensor(quant_input) return quant_input def maybe_quantize_state(self, inp, state, quant): @@ -276,7 +284,7 @@ def maybe_quantize_state(self, inp, state, quant): batch_size = inp.size(0) if self.cell.batch_first else inp.size(1) quant_state = torch.zeros( int(batch_size), self.hidden_size, dtype=inp.dtype, device=inp.device) - quant_state = QuantTensor(quant_state) + # quant_state = QuantTensor(quant_state) else: quant_state = quant(state) return quant_state @@ -303,7 +311,8 @@ def pack_quant_outputs(self, quant_outputs): quant_output[2], quant_output[3], self.io_quant.is_signed, - self.training) for quant_output in quant_outputs] + self.training, + _allow_empty=True) for quant_output in quant_outputs] else: outputs = [torch.unsqueeze(o[0], dim=seq_dim) for o in quant_outputs] if self.reverse_input: @@ -331,7 +340,8 @@ def pack_quant_state(self, quant_state, quant): quant_state[2], quant_state[3], quant.is_signed, - self.training) + training=self.training, + _allow_empty=True) else: quant_state = torch.unsqueeze(quant_state[0], dim=0) return quant_state diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 7208aa8e3..15cc8bce3 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -19,6 +19,10 @@ from .utils import rename_state_dict_by_prefix +def return_value(tensor): + return tensor.value if isinstance(tensor, QuantTensor) else tensor + + class QuantNonLinearActLayer(QuantNonLinearActMixin, QuantInputMixin, QuantLayerMixin, Module): __metaclass__ = ABCMeta @@ -308,56 +312,82 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe return out quant_input = self.input_quant(inp) + # quant_input_value = getattr(quant_input, 'value', quant_input) + # quant_input_scale = getattr(quant_input, 'scale', None) + # quant_input_bitwidth = getattr(quant_input, 'bit_width', None) + quant_weight = self.quant_weight(quant_input) + # quant_weight_value = getattr(quant_weight, 'value', quant_weight) + # quant_weight_scale = getattr(quant_weight, 'scale', None) + # quant_weight_bitwidth = getattr(quant_weight, 'bit_width', None) + + if ((not isinstance(quant_input, QuantTensor) or not isinstance(quant_weight, QuantTensor)) + and not self.is_output_quant_enabled) and self.return_quant_tensor: + raise RuntimeError("QuantLayer is not correctly configured") if (self.return_quant_tensor or (self.is_bias_quant_enabled and (self.bias_quant.requires_input_scale or self.bias_quant.requires_input_bit_width))): - if quant_input.bit_width is not None and quant_weight.bit_width is not None: + if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): output_bit_width = self.max_acc_bit_width( quant_input.bit_width, quant_weight.bit_width) - if quant_input.scale is not None and quant_weight.scale is not None: + output_scale = self.quant_output_scale_impl( inp, quant_input.scale, quant_weight.scale) - if quant_input.signed is not None: - output_signed = inp.signed or quant_weight.signed + + quant_input_signed = quant_input.signed if isinstance( + quant_input, QuantTensor) else True + quant_weight_signed = quant_weight.signed if isinstance( + quant_weight, QuantTensor) else True + output_signed = quant_input_signed or quant_weight_signed if self.bias is not None: quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width) + quant_bias_value = getattr(quant_bias, 'value', quant_bias) + quant_bias_scale = getattr(quant_bias, 'scale', None) + quant_bias_bitwidth = getattr(quant_bias, 'bit_width', None) if not self.training and self.cache_inference_quant_bias: self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False) - output_tensor = self.inner_forward_impl( - quant_input.value, quant_weight.value, quant_bias.value) + return_value(quant_input), return_value(quant_weight), return_value(quant_bias)) if (self.return_quant_tensor and output_scale is not None and - (quant_bias.scale is None or - (quant_bias.scale is not None and - quant_bias.scale.data_ptr() != output_scale.data_ptr()))): - output_scale_broadcast_shape = compute_channel_view_shape(inp, channel_dim=1) - output_zero_point = -quant_bias.value.view( + (quant_bias_scale is None or + (quant_bias_scale is not None and + quant_bias_scale.data_ptr() != output_scale.data_ptr()))): + channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 + output_scale_broadcast_shape = compute_channel_view_shape( + inp, channel_dim=channel_dim) + output_zero_point = -quant_bias_value.view( output_scale_broadcast_shape) / output_scale - if quant_bias.bit_width is not None and output_bit_width is not None: + if hasattr(quant_bias, 'bit_width' + ) and quant_bias_bitwidth is not None and output_bit_width is not None: output_bit_width = torch.where( - quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width) + quant_bias_bitwidth > output_bit_width, quant_bias_bitwidth, output_bit_width) output_bit_width = output_bit_width + 1 else: - output_tensor = self.inner_forward_impl(quant_input.value, quant_weight.value, None) + output_tensor = self.inner_forward_impl( + return_value(quant_input), return_value(quant_weight), None) if self.return_quant_tensor and not self.is_output_quant_enabled: - if (quant_input.zero_point is not None and quant_weight.zero_point is not None and + if (isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor) and ((quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any())): raise RuntimeError("Computing zero point of output accumulator not supported yet.") elif quant_input.zero_point is not None and output_zero_point is None: output_zero_point = quant_input.zero_point + elif self.return_quant_tensor and output_zero_point is None: + output_zero_point = torch.zeros(1).type_as(output_tensor) - quant_output = QuantTensor( - value=output_tensor, - scale=output_scale, - zero_point=output_zero_point, - bit_width=output_bit_width, - signed=output_signed, - training=self.training) + if not self.return_quant_tensor or (output_scale is None and output_zero_point is None): + quant_output = output_tensor + else: + quant_output = QuantTensor.from_fake_quantized( + output_tensor, + scale=output_scale, + zero_point=output_zero_point, + bit_width=output_bit_width, + signed=output_signed, + training=self.training) quant_output = self.output_quant(quant_output) return self.pack_output(quant_output) diff --git a/src/brevitas/nn/quant_mha.py b/src/brevitas/nn/quant_mha.py index 6720fe280..020ed2d31 100644 --- a/src/brevitas/nn/quant_mha.py +++ b/src/brevitas/nn/quant_mha.py @@ -409,7 +409,7 @@ def multi_head_attention( # Mark dimensions through named tensors. if not torch._C._get_tracing_state(): if isinstance(query, QuantTensor): - query.value.rename_('L', 'N', 'E') + query.qt_value.rename_('L', 'N', 'E') else: query.rename_('L', 'N', 'E') # self-attention @@ -426,7 +426,7 @@ def multi_head_attention( if not torch._C._get_tracing_state(): for t in [query, key, value]: if isinstance(t, QuantTensor): - t.value.rename_('L', 'N', 'E') + t.qt_value.rename_('L', 'N', 'E') else: t.rename_('L', 'N', 'E') q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value) @@ -573,7 +573,7 @@ def multi_head_attention( # Remove names to avoid errors un unsupported downstream ops if not torch._C._get_tracing_state(): if isinstance(attn_output, QuantTensor): - attn_output.value.rename_(None) + attn_output.qt_value.rename_(None) else: attn_output.rename_(None) diff --git a/src/brevitas/nn/quant_rnn.py b/src/brevitas/nn/quant_rnn.py index 396a4f6ef..8268e9eb9 100644 --- a/src/brevitas/nn/quant_rnn.py +++ b/src/brevitas/nn/quant_rnn.py @@ -23,6 +23,8 @@ from brevitas.quant import Int8WeightPerTensorFloat from brevitas.quant import Int32Bias from brevitas.quant import Uint8ActPerTensorFloat +from brevitas.quant_tensor import _get_dequantize_tensor +from brevitas.quant_tensor import QuantTensor QuantTupleShortEnabled = List[Tuple[Tensor, Tensor, Tensor, Tensor]] QuantTupleShortDisabled = List[Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]] @@ -416,10 +418,10 @@ def forward(self, inp, state): quant_input = self.maybe_quantize_input(inp) quant_weight_ih, quant_weight_hh, quant_bias = self.gate_params_fwd( self.gate_params, quant_input) - if quant_bias.value is None: + if getattr(quant_bias, 'value', quant_bias) is None: quant_bias = torch.tensor(0., device=quant_input.value.device) else: - quant_bias = quant_bias.value + quant_bias = _get_dequantize_tensor(quant_bias) quant_state = self.maybe_quantize_state(quant_input.value, state, self.cell.output_quant) if self.export_mode: cell = self.export_handler @@ -428,10 +430,10 @@ def forward(self, inp, state): else: cell = self.cell quant_outputs = cell( - quant_input.value, - quant_state.value, - quant_weight_ih.value, - quant_weight_hh.value, + _get_dequantize_tensor(quant_input), + _get_dequantize_tensor(quant_state), + _get_dequantize_tensor(quant_weight_ih), + _get_dequantize_tensor(quant_weight_hh), quant_bias) quant_output = self.pack_quant_outputs(quant_outputs) quant_state = self.pack_quant_state(quant_outputs[-1], self.cell.output_quant) @@ -666,6 +668,7 @@ def fast_cell(self): def forward(self, inp, hidden_state, cell_state): quant_input = self.maybe_quantize_input(inp) + quant_input_value = _get_dequantize_tensor(quant_input) quant_weight_ii, quant_weight_hi, quant_bias_input = self.gate_params_fwd( self.input_gate_params, quant_input) quant_weight_ic, quant_weight_hc, quant_bias_cell = self.gate_params_fwd( @@ -680,26 +683,26 @@ def forward(self, inp, hidden_state, cell_state): quant_weight_if, quant_weight_hf, quant_bias_forget = self.gate_params_fwd( self.forget_gate_params, quant_input) # Handle None bias by setting it 0. - if quant_bias_input.value is None: - quant_bias_input = torch.tensor(0., device=quant_input.value.device) + if getattr(quant_bias_input, 'value', quant_bias_input) is None: + quant_bias_input = torch.tensor(0., device=quant_input_value.device) else: - quant_bias_input = quant_bias_input.value - if quant_bias_forget.value is None: - quant_bias_forget = torch.tensor(0., device=quant_input.value.device) + quant_bias_input = _get_dequantize_tensor(quant_bias_input) + if getattr(quant_bias_forget, 'value', quant_bias_forget) is None: + quant_bias_forget = torch.tensor(0., device=quant_input_value.device) else: - quant_bias_forget = quant_bias_forget.value - if quant_bias_cell.value is None: - quant_bias_cell = torch.tensor(0., device=quant_input.value.device) + quant_bias_forget = _get_dequantize_tensor(quant_bias_forget) + if getattr(quant_bias_cell, 'value', quant_bias_cell) is None: + quant_bias_cell = torch.tensor(0., device=quant_input_value.device) else: - quant_bias_cell = quant_bias_cell.value - if quant_bias_output.value is None: - quant_bias_output = torch.tensor(0., device=quant_input.value.device) + quant_bias_cell = _get_dequantize_tensor(quant_bias_cell) + if getattr(quant_bias_output, 'value', quant_bias_output) is None: + quant_bias_output = torch.tensor(0., device=quant_input_value.device) else: - quant_bias_output = quant_bias_output.value + quant_bias_output = _get_dequantize_tensor(quant_bias_output) quant_hidden_state = self.maybe_quantize_state( - quant_input.value, hidden_state, self.cell.output_quant) + quant_input_value, hidden_state, self.cell.output_quant) quant_cell_state = self.maybe_quantize_state( - quant_input.value, cell_state, self.cell.cell_state_quant) + quant_input_value, cell_state, self.cell.cell_state_quant) # Pick cell impl if self.export_mode: cell = self.export_handler @@ -708,17 +711,17 @@ def forward(self, inp, hidden_state, cell_state): else: cell = self.cell quant_outputs, quant_hidden_state, quant_cell_state = cell( - quant_input.value, - quant_hidden_state.value, - quant_cell_state.value, - quant_weight_ii=quant_weight_ii.value, - quant_weight_if=quant_weight_if.value, - quant_weight_ic=quant_weight_ic.value, - quant_weight_io=quant_weight_io.value, - quant_weight_hi=quant_weight_hi.value, - quant_weight_hf=quant_weight_hf.value, - quant_weight_hc=quant_weight_hc.value, - quant_weight_ho=quant_weight_ho.value, + quant_input_value, + _get_dequantize_tensor(quant_hidden_state), + _get_dequantize_tensor(quant_cell_state), + quant_weight_ii=_get_dequantize_tensor(quant_weight_ii), + quant_weight_if=_get_dequantize_tensor(quant_weight_if), + quant_weight_ic=_get_dequantize_tensor(quant_weight_ic), + quant_weight_io=_get_dequantize_tensor(quant_weight_io), + quant_weight_hi=_get_dequantize_tensor(quant_weight_hi), + quant_weight_hf=_get_dequantize_tensor(quant_weight_hf), + quant_weight_hc=_get_dequantize_tensor(quant_weight_hc), + quant_weight_ho=_get_dequantize_tensor(quant_weight_ho), quant_bias_input=quant_bias_input, quant_bias_forget=quant_bias_forget, quant_bias_cell=quant_bias_cell, diff --git a/src/brevitas/nn/utils.py b/src/brevitas/nn/utils.py index 3e7b423ee..996f8f256 100644 --- a/src/brevitas/nn/utils.py +++ b/src/brevitas/nn/utils.py @@ -74,8 +74,8 @@ def check_tensors_same_ptr(tensor_list): if hasattr(t, 'data_ptr'): ptr = t.data_ptr() pointers.append(ptr) - elif hasattr(t, 'value') and hasattr(t.value, 'data_ptr'): - pointers.append(t.value.data_ptr()) + elif hasattr(t, 'qt_value') and hasattr(t.qt_value, 'data_ptr'): + pointers.append(t.qt_value.data_ptr()) else: return False return all(p == pointers[0] for p in pointers) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 5a4b2ed55..f7f120697 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -100,7 +100,7 @@ def forward(self, x: torch.Tensor) -> QuantTensor: out, scale, zero_point, bit_width = impl(x) return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) else: # quantization disabled - return QuantTensor(x, training=self.training) + return x class DecoupledWeightQuantProxyFromInjector(WeightQuantProxyFromInjector): @@ -218,4 +218,4 @@ def forward( raise RuntimeError("Internally defined bit-width required") return QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) else: - return QuantTensor(x, training=self.training) + return x diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index a650a7755..678f60736 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -157,16 +157,20 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor: elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant if isinstance(y, tuple): y = y[0] - return QuantTensor(y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) + if isinstance(x, QuantTensor): + return QuantTensor( + y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) + else: + return y else: if isinstance(y, tuple): y = y[0] - return QuantTensor(y, training=self.training) + return y else: if isinstance(x, QuantTensor): # passthrough return x else: - return QuantTensor(x, training=self.training) + return x class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol): diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index e017f5de3..4d60545dd 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -3,6 +3,7 @@ from abc import ABC from typing import NamedTuple, Optional +import warnings import torch from torch import Tensor @@ -17,8 +18,12 @@ IS_VALID_ATOL = 2e-1 +def _get_dequantize_tensor(input): + return input.value if isinstance(input, QuantTensor) else input + + class QuantTensorBase(NamedTuple): - value: Tensor + qt_value: Optional[Tensor] scale: Optional[Tensor] zero_point: Optional[Tensor] bit_width: Optional[Tensor] @@ -53,7 +58,14 @@ def _is_all_nested_not_none(input_data): class QuantTensor(QuantTensorBase): def __new__( - cls, value, scale=None, zero_point=None, bit_width=None, signed=None, training=None): + cls, + qt_value=None, + scale=None, + zero_point=None, + bit_width=None, + signed=None, + training=None, + _allow_empty=False): if scale is not None and not isinstance(scale, torch.Tensor): scale = torch.tensor(scale, dtype=torch.float) @@ -65,7 +77,27 @@ def __new__( signed = torch.tensor(signed, dtype=torch.bool) if training is not None and not isinstance(training, torch.Tensor): training = torch.tensor(training, dtype=torch.bool) - return super().__new__(cls, value, scale, zero_point, bit_width, signed, training) + + if _allow_empty: + warnings.warn( + "Empty QuantTensor are deprecated and will be removed in a future version") + # elif value is not None and scale is not None and zero_point is not None: + # is_int = torch.allclose(torch.round(int_value), int_value) + # if not is_int: + # quant_tensor = quant_tensor.set(int_value = torch.round(int_value / scale + zero_point)) + # elif int_value is None and value is not None: + # pass + elif not _allow_empty and (scale is None or bit_width is None or zero_point is None): + raise RuntimeError("To create an emtpy QuantTensor, set _allow_empty=True") + + quant_tensor = super().__new__( + cls, qt_value, scale, zero_point, bit_width, signed, training) + return quant_tensor + + @classmethod + def from_fake_quantized(cls, fake_quant_value, scale, zero_point, bit_width, signed, training): + quant_tensor = torch.round(fake_quant_value / scale + zero_point) + return cls(quant_tensor, scale, zero_point, bit_width, signed, training) @property def signed(self): @@ -96,6 +128,13 @@ def __torch_function__(self, func, types, args=(), kwargs=None): def tensor(self): return self.value + @property + def value(self): + if self.is_valid: + return (self.qt_value - self.zero_point) * self.scale + else: + return self.qt_value + @property def is_not_none(self): return ( @@ -110,33 +149,42 @@ def _pre_round_int_value(self): @property def is_valid(self): - if self.is_not_none: - with torch.no_grad(): - pre_round_int_value = self._pre_round_int_value - rounded_int_value = torch.round(pre_round_int_value) - is_int = torch.isclose( - pre_round_int_value, rounded_int_value, atol=IS_VALID_ATOL).all() - if self.bit_width >= 2: - if self.signed: - is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all() - is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all() - else: - is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all() - is_lower_b = (0. <= rounded_int_value).all() - return (is_int & is_upper_b & is_lower_b).item() - else: # binary case - unique_vals = rounded_int_value.unique( - sorted=False, return_counts=False, return_inverse=False) - is_binary = unique_vals.view(-1).size()[0] == 2 - is_signed = (unique_vals < 0.).any().item() - sign_match = is_signed == self.signed - return is_int.item() and is_binary and sign_match + if torch.allclose(self.qt_value.to(torch.int).to(self.qt_value.dtype), + self.qt_value, + rtol=0, + atol=0): + return True else: return False + # return True + # # if self.is_not_none: + # # with torch.no_grad(): + # # pre_round_int_value = self._pre_round_int_value + # # rounded_int_value = torch.round(pre_round_int_value) + # # is_int = torch.isclose( + # # pre_round_int_value, rounded_int_value, atol=IS_VALID_ATOL).all() + # # if self.bit_width >= 2: + # # if self.signed: + # # is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all() + # # is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all() + # # else: + # # is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all() + # # is_lower_b = (0. <= rounded_int_value).all() + # # return (is_int & is_upper_b & is_lower_b).item() + # # else: # binary case + # # unique_vals = rounded_int_value.unique( + # # sorted=False, return_counts=False, return_inverse=False) + # # is_binary = unique_vals.view(-1).size()[0] == 2 + # # is_signed = (unique_vals < 0.).any().item() + # # sign_match = is_signed == self.signed + # # return is_int.item() and is_binary and sign_match + # # else: + # # return False + @property def device(self): - value_device = self.value.device + value_device = self.qt_value.device is_same_device = True for t in [self.scale, self.zero_point, self.bit_width]: if t is not None: @@ -149,14 +197,14 @@ def set(self, **kwargs): return self._replace(**kwargs) def detach_(self): - self.value.detach_() + self.qt_value.detach_() self.scale.detach_() self.zero_point.detach_() self.bit_width.detach_() def detach(self): return QuantTensor( - self.value.detach(), + self.qt_value.detach(), self.scale.detach() if self.scale is not None else None, self.zero_point.detach() if self.zero_point is not None else None, self.bit_width.detach() if self.bit_width is not None else None, @@ -165,7 +213,7 @@ def detach(self): def contiguous(self): return QuantTensor( - self.value.contiguous(), + self.qt_value.contiguous(), self.scale.contiguous() if self.scale is not None else None, self.zero_point.contiguous() if self.zero_point is not None else None, self.bit_width.contiguous() if self.bit_width is not None else None, @@ -174,18 +222,22 @@ def contiguous(self): def int(self, float_datatype=False): if self.is_valid: - int_value = round_ste(self._pre_round_int_value) - if float_datatype: - return int_value - else: - if self.bit_width <= 8. and self.signed_t.item(): - return int_value.to(torch.int8) - elif self.bit_width <= 8. and not self.signed_t.item(): - return int_value.to(torch.uint8) - else: - return int_value.to(torch.int32) + return self.qt_value else: raise RuntimeError(f"QuantTensor not valid.") + # if self.is_valid: + # int_value = round_ste(self._pre_round_int_value) + # if float_datatype: + # return int_value + # else: + # if self.bit_width <= 8. and self.signed_t.item(): + # return int_value.to(torch.int8) + # elif self.bit_width <= 8. and not self.signed_t.item(): + # return int_value.to(torch.uint8) + # else: + # return int_value.to(torch.int32) + # else: + # raise RuntimeError(f"QuantTensor not valid.") @staticmethod def check_input_type(tensor): @@ -201,10 +253,9 @@ def is_zero_zero_point(tensor): return None def check_scaling_factors_same(self, other): - if self.training is not None and self.training: - return True if not torch.allclose(self.scale, other.scale): - raise RuntimeError("Scaling factors are different") + return False + return True def check_zero_points_same(self, other): if self.training is not None and self.training: @@ -221,41 +272,41 @@ def check_sign_same(self, other): raise RuntimeError("Signs are different") def view(self, *args, **kwargs): - return self.set(value=self.value.view(*args, **kwargs)) + return self.set(int_value=self.qt_value.view(*args, **kwargs)) def reshape(self, *args, **kwargs): - return self.set(value=self.value.reshape(*args, **kwargs)) + return self.set(int_value=self.qt_value.reshape(*args, **kwargs)) def flatten(self, *args, **kwargs): - return self.set(value=self.value.flatten(*args, **kwargs)) + return self.set(int_value=self.qt_value.flatten(*args, **kwargs)) def transpose(self, *args, **kwargs): - value = self.value.transpose(*args, **kwargs) + qt_value = self.qt_value.transpose(*args, **kwargs) tensor_meta = { 'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width} for k, tm in tensor_meta.items(): - if tm is not None and len(value.shape) == len(tm.shape): + if tm is not None and len(qt_value.shape) == len(tm.shape): tensor_meta[k] = tm.transpose(*args, **kwargs) - return self.set(value=value, **tensor_meta) + return self.set(qt_value=qt_value, **tensor_meta) def permute(self, *args, **kwargs): - value = self.value.permute(*args, **kwargs) + int_value = self.qt_value.permute(*args, **kwargs) tensor_meta = { 'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width} for k, tm in tensor_meta.items(): - if tm is not None and len(value.shape) == len(tm.shape): + if tm is not None and len(int_value.shape) == len(tm.shape): tensor_meta[k] = tm.permute(*args, **kwargs) - return self.set(value=value, **tensor_meta) + return self.set(int_value=int_value, **tensor_meta) def size(self, *args, **kwargs): - return self.value.size(*args, **kwargs) + return self.qt_value.size(*args, **kwargs) @property def shape(self): - return self.value.shape + return self.qt_value.shape def dim(self): - return self.value.dim() + return self.qt_value.dim() def add(self, other): return self + other @@ -270,11 +321,17 @@ def cat(tensors, dim, out=None): first_qt = tensors[0] if all([isinstance(qt, QuantTensor) and qt.is_not_none for qt in tensors]): for qt in tensors[1:]: - first_qt.check_scaling_factors_same(qt) - first_qt.check_zero_points_same(qt) - first_qt.check_bit_width_same(qt) - first_qt.check_sign_same(qt) - output_value = torch.cat([qt.value for qt in tensors], dim=dim) + is_output_qt_valid = True + if first_qt.training is not None and first_qt.training: + if not (first_qt.check_scaling_factors_same(qt) and + first_qt.check_zero_points_same(qt) and + first_qt.check_bit_width_same(qt) and first_qt.check_sign_same(qt)): + is_output_qt_valid = False + elif not (first_qt.check_scaling_factors_same(qt) and + first_qt.check_zero_points_same(qt) and + first_qt.check_bit_width_same(qt) and first_qt.check_sign_same(qt)): + raise RuntimeError("Scaling factors are different") + output_training = any([qt.training for qt in tensors]) if output_training: output_scale = sum([qt.scale for qt in tensors]) / len(tensors) @@ -285,25 +342,32 @@ def cat(tensors, dim, out=None): output_zero_point = first_qt.zero_point output_bit_width = first_qt.bit_width output_signed = first_qt.signed # they are the same - return QuantTensor( - value=output_value, + output_value = torch.cat([qt.int() for qt in tensors], + dim=dim) if is_output_qt_valid else torch.cat( + [qt.value for qt in tensors], dim=dim) + + output = QuantTensor( + output_value, scale=output_scale, zero_point=output_zero_point, bit_width=output_bit_width, signed=output_signed, training=output_training) + + return output else: tensors = [qt.value if isinstance(qt, QuantTensor) else qt for qt in tensors] output_value = torch.cat(tensors, dim=dim) return output_value + # Reference: https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types def __neg__(self): neg_value = (-self.int(float_datatype=True) - self.zero_point) * self.scale if self.signed: return QuantTensor( - value=neg_value, + int_value=neg_value, scale=self.scale, zero_point=self.zero_point, bit_width=self.bit_width, @@ -311,7 +375,7 @@ def __neg__(self): training=self.training) else: return QuantTensor( - value=neg_value, + int_value=neg_value, scale=self.scale, zero_point=self.zero_point, bit_width=self.bit_width + 1, @@ -320,7 +384,7 @@ def __neg__(self): def to(self, *args, **kwargs): return QuantTensor( - self.value.to(*args, **kwargs), + self.qt_value.to(*args, **kwargs), self.scale.to(*args, **kwargs) if self.scale is not None else None, self.zero_point.to(*args, **kwargs) if self.zero_point is not None else None, self.bit_width.to(*args, **kwargs) if self.bit_width is not None else None, @@ -329,7 +393,7 @@ def to(self, *args, **kwargs): def cuda(self, *args, **kwargs): return QuantTensor( - self.value.cuda(*args, **kwargs), + self.qt_value.cuda(*args, **kwargs), self.scale.cuda(*args, **kwargs) if self.scale is not None else None, self.zero_point.cuda(*args, **kwargs) if self.zero_point is not None else None, self.bit_width.cuda(*args, **kwargs) if self.bit_width is not None else None, @@ -338,7 +402,7 @@ def cuda(self, *args, **kwargs): def cpu(self, *args, **kwargs): return QuantTensor( - self.value.cpu(*args, **kwargs), + self.qt_value.cpu(*args, **kwargs), self.scale.cpu(*args, **kwargs) if self.scale is not None else None, self.zero_point.cpu(*args, **kwargs) if self.zero_point is not None else None, self.bit_width.cpu(*args, **kwargs) if self.bit_width is not None else None, @@ -347,8 +411,13 @@ def cpu(self, *args, **kwargs): def __add__(self, other): if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none: - self.check_scaling_factors_same(other) - output_value = self.value + other.value + is_output_qt_valid = True + if self.training is not None and self.training: + if not self.check_scaling_factors_same(other): + is_output_qt_valid = False + elif not self.check_scaling_factors_same(other): + raise RuntimeError("Scaling factors are different") + output_scale = (self.scale + other.scale) / 2 output_zero_point = self.zero_point + other.zero_point max_val = max_int(signed=self.signed, narrow_range=False, bit_width=self.bit_width) @@ -358,15 +427,16 @@ def __add__(self, other): output_bit_width = ceil_ste(torch.log2(max_val - min_val)) output_signed = self.signed or other.signed output_training = self.training or other.training + output_value = self.int() + other.int( + ) if is_output_qt_valid else self.value + other.value + output = QuantTensor( - value=output_value, + output_value, scale=output_scale, zero_point=output_zero_point, bit_width=output_bit_width, signed=output_signed, training=output_training) - elif isinstance(other, QuantTensor): - output = self.value + other.value else: output = self.value + other return output @@ -388,8 +458,8 @@ def __mul__(self, other): output_zero_point = self.zero_point * other.zero_point else: raise RuntimeError("Zero-points of mul operands are non-zero, not supported.") - output = QuantTensor( - value=output_value, + output = QuantTensor.from_fake_quantized( + output_value, scale=output_scale, zero_point=output_zero_point, bit_width=output_bit_width, @@ -404,9 +474,12 @@ def __mul__(self, other): def __sub__(self, other): return self.__add__(-other) + def __str__(self): + return f"QuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" + def __truediv__(self, other): if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none: - output_tensor = self.value / other.tensor + output_tensor = self.qt_value / other.tensor output_scale = self.scale / other.scale output_bit_width = self.bit_width - other.bit_width output_signed = self.signed or other.signed @@ -416,7 +489,7 @@ def __truediv__(self, other): else: output_zero_point = None # TODO non-zero zero point output = QuantTensor( - value=output_tensor, + int_value=output_tensor, scale=output_scale, zero_point=output_zero_point, bit_width=output_bit_width, @@ -432,7 +505,7 @@ def __abs__(self): if self.signed: abs_value = (torch.abs(self.int(float_datatype=True)) - self.zero_point) * self.scale return QuantTensor( - value=abs_value, + int_value=abs_value, scale=self.scale, zero_point=self.zero_point, bit_width=self.bit_width - 1, diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py index 3b64bca89..6591a01ce 100644 --- a/src/brevitas/quant_tensor/torch_handler.py +++ b/src/brevitas/quant_tensor/torch_handler.py @@ -24,7 +24,7 @@ def decorator(func): def quant_invariant_handler(fn, inp, *args, **kwargs): out_value = fn(inp.value, *args, **kwargs) if inp.is_not_none: - return inp.set(value=out_value) + return inp.set(qt_value=torch.round(out_value / inp.scale + inp.zero_point)) else: return out_value diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 6f4243741..0cf8dfbaf 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -45,10 +45,10 @@ parser = argparse.ArgumentParser(description='PyTorch ImageNet PTQ Validation') parser.add_argument( '--calibration-dir', - required=True, + required=False, help='Path to folder containing Imagenet calibration folder') parser.add_argument( - '--validation-dir', required=True, help='Path to folder containing Imagenet validation folder') + '--validation-dir', required=False, help='Path to folder containing Imagenet validation folder') parser.add_argument( '--workers', default=8, type=int, help='Number of data loading workers (default: 8)') parser.add_argument( @@ -385,6 +385,11 @@ def main(): print("Applying bias correction:") apply_bias_correction(calib_loader, quant_model) + from brevitas.graph.calibrate import DisableEnableQuantization + disable_enable_quant = DisableEnableQuantization() + # disable_enable_quant.disable_act_quantization(model, False) + # disable_enable_quant.disable_bias_quantization(model, False) + # Validate the quant_model on the validation dataloader print("Starting validation:") validate(val_loader, quant_model) diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 58561cbd1..0cb4d3fd4 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -44,7 +44,8 @@ class TestModel(nn.Module): def __init__(self): super(TestModel, self).__init__() - self.act = qnn.QuantIdentity(act_quant=Int8ActPerTensorFixedPoint) + self.act = qnn.QuantIdentity( + act_quant=Int8ActPerTensorFixedPoint, return_quant_tensor=True) def forward(self, x): return self.act(x)