Commit

doc: init doc
calad0i committed Aug 7, 2023
1 parent 2e4ebd5 commit 7ecfaa8
Showing 5 changed files with 210 additions and 13 deletions.
25 changes: 24 additions & 1 deletion README.md
@@ -1,3 +1,26 @@
# High Granularity Quantization for Ultra-Fast Inference on FPGAs

This file is work-in-progress. Please check back later.
HGQ is a method for quantization-aware training of neural networks to be deployed on FPGAs, which allows for per-weight and per-activation bitwidth optimization.

Depending on the specific [application](https://arxiv.org/abs/2006.10159), HGQ could achieve up to a 10x resource reduction compared to the traditional `AutoQkeras` approach, while maintaining the same accuracy. For some more challenging [tasks](https://arxiv.org/abs/2202.04976), where the model is already under-fitted, HGQ could still improve the performance under the same on-board resource consumption. For more details, please refer to our paper (link will be added once the paper is available).

This repository implements HGQ for `tensorflow.keras` models. It is independent of the [QKeras project](https://github.com/google/qkeras).

Notice: this repository is still under development, and the API might change in the future.

## Installation

`pip install HGQ`, and you are good to go. Note that HGQ requires `python3.10` and `tensorflow>=2.11`.

## Usage Guide

Please refer to the [usage guide](./usage_guide.md) for more details.
This [repo](https://github.com/calad0i/HGQ-demos) contains some use cases for HGQ.

## FAQ

Please refer to the [FAQ](./faq.md) for more details.

## Citation

The paper is not ready. Please check back later.
36 changes: 36 additions & 0 deletions faq.md
@@ -0,0 +1,36 @@
# FAQs

## What's this?

HGQ is a method for quantization-aware training of neural networks to be deployed on FPGAs, which allows for per-weight and per-activation bitwidth optimization.

## Why is it useful?

Depending on the specific [application](https://arxiv.org/abs/2006.10159), HGQ could achieve up to a 10x resource reduction compared to the traditional `AutoQkeras` approach, while maintaining the same accuracy. For some more challenging [tasks](https://arxiv.org/abs/2202.04976), where the model is already under-fitted, HGQ could still improve the performance under the same on-board resource consumption. For more details, please refer to our paper (link will be added once the paper is available).

## Can I use it?

The following conditions must be met:

1. Your model is compatible with `hls4ml` (i.e. it can be converted to HLS C++ code).
2. You are using `Vivado` as your FPGA backend.
   - However, other backends MAY work if you don't use heterogeneous activation quantization.
3. You are using `tensorflow.keras` as your model API.
4. Your model is fully connected or convolutional.
   - i.e. no RNN, LSTM, etc. Transformers should work if you build them from dense layers and the conv1d-for-matrix-multiplication (MMM) hack.
5. You are using `tensorflow` as your training framework.
   - Both the `Sequential` and `Functional` Keras model APIs are supported.

If you meet all the above conditions, you can probably use HGQ to quantize your model.
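
For illustration, a minimal model satisfying these conditions might look like the following sketch (the layer names and the `bops_reg_factor` keyword are taken from the [usage guide](./usage_guide.md); exact signatures may differ):

```python
from tensorflow.keras.models import Sequential
from HGQ import HDense, HQuantize

bops_reg_factor = 1e-5  # regularization strength on the estimated BOPs

# A fully connected tf.keras Sequential model built only from HGQ layers,
# which should satisfy conditions 1-5 above.
model = Sequential([
    HQuantize(bops_reg_factor=bops_reg_factor, name='inp_q', input_shape=(16,)),
    HDense(32, activation='relu', bops_reg_factor=bops_reg_factor),
    HDense(5, bops_reg_factor=bops_reg_factor),
])
```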

## How do I get started?

Please refer to the [usage guide](./usage_guide.md) for more details.

## What's the status of the project?

The project is still under development. The codebase and documentation are not yet stable, and we are actively working on them. If you have any questions, please feel free to contact us.

## How do I cite this work?

The paper is not available yet. Please check back later.
6 changes: 3 additions & 3 deletions src/HGQ/layers/base.py
@@ -204,7 +204,7 @@ def result_container(self) -> str:
if self.pre_activation_quantizer.rnd_strategy != 3 and not self._has_bias:
fp_max += 1
assert np.sum(kn[int_bits + fp_bits <= 0]
) == 0, f'Bit counting error at {self.name}. This should never happen. Please try again with cuda disabled (2^13 or above will may in error when tensorflow is run with cuda).'
) == 0, f'Bit counting error at {self.name}. Did you forget to call `compute_bops` before passing the model to the converter? If not, please try again with cuda disabled (values of 2^13 or above may result in errors when tensorflow is run with cuda). If the error persists, this should never happen; please open an issue at https://github.com/calad0i/HGQ'
return tuple_to_apf((kn_max, int_max, fp_max))

@property
@@ -243,7 +243,7 @@ def act_container(self) -> str:
mask = int_bits + fp_bits > 0
int_max, fp_max, kn_max = int_bits[mask].max(), fp_bits[mask].max(), kn[mask].max()
assert np.sum(
kn[~mask]) == 0, f'Bit counting error at {self.name}. This should never happen. Please try again with cuda disabled (2^13 or above will may in error when tensorflow is run with cuda).'
kn[~mask]) == 0, f'Bit counting error at {self.name}. Did you forget to call `compute_bops` before passing the model to the converter? If not, please try again with cuda disabled (values of 2^13 or above may result in errors when tensorflow is run with cuda). If the error persists, this should never happen; please open an issue at https://github.com/calad0i/HGQ'
return tuple_to_apf((kn_max, int_max, fp_max))

@property
@@ -252,7 +252,7 @@ def ker_container(self):
int_bits, fp_bits, kn = self.kernel_quantizer.get_bits_exact(self.kernel)
mask = int_bits + fp_bits > 0
assert np.sum(
kn[~mask]) == 0, f'Bit counting error at {self.name}. This should never happen. Please try again with cuda disabled (2^13 or above will may in error when tensorflow is run with cuda).'
kn[~mask]) == 0, f'Bit counting error at {self.name}. Please try again with cuda disabled (values of 2^13 or above may result in errors when tensorflow is run with cuda). If the error persists, this should never happen; please open an issue at https://github.com/calad0i/HGQ'
int_max, fp_max, kn_max = int_bits[mask].max(), fp_bits[mask].max(), kn[mask].max()
return tuple_to_apf((kn_max, int_max, fp_max))

18 changes: 9 additions & 9 deletions src/HGQ/layers/misc.py
@@ -53,14 +53,14 @@ def compute_output_shape(self, input_shape):
return input_shape[0]


class HMultply(HLayerBase, _Merge):
# class HMultply(HLayerBase, _Merge):

@tf.function(jit_compile=True)
def forward(self, inputs, training=None, record_minmax=None):
output = inputs[0]
for i in range(1, len(inputs)):
output *= inputs[i]
return self.pre_activation_quantizer(output, training=training, record_minmax=record_minmax) # type: ignore
# @tf.function(jit_compile=True)
# def forward(self, inputs, training=None, record_minmax=None):
# output = inputs[0]
# for i in range(1, len(inputs)):
# output *= inputs[i]
# return self.pre_activation_quantizer(output, training=training, record_minmax=record_minmax) # type: ignore

def compute_output_shape(self, input_shape):
return input_shape[0]
# def compute_output_shape(self, input_shape):
# return input_shape[0]
138 changes: 138 additions & 0 deletions usage_guide.md
@@ -0,0 +1,138 @@
# Usage Guide

## Installation

`pip install HGQ`, and you are good to go. Note that HGQ requires `python3.10` and `tensorflow>=2.11`.
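
If you want to verify the environment before installing, a quick sanity check might look like this (a sketch; it assumes any Python 3.10 or newer is acceptable):

```python
import sys

import tensorflow as tf

# Check the stated requirements: python3.10 and tensorflow>=2.11.
assert sys.version_info >= (3, 10), 'HGQ requires python 3.10'
assert tuple(int(v) for v in tf.__version__.split('.')[:2]) >= (2, 11), 'HGQ requires tensorflow>=2.11'
```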

## Getting Started

You need a minimum of four extra steps to use HGQ in your keras project:

```python

from HGQ import HDense, HQuantize
from HGQ.bops import compute_bops, ResetMinMax
from tensorflow.keras.models import Sequential
from HGQ.hls4ml_hook import convert_from_hgq_model
...

# regularization factor on MBOPs; higher values push toward smaller bitwidths
bops_reg_factor = 1e-5

# The first layer must be quantized, either by using HQuantize or Signature layers.
# The input quantization layer's name must contain 'inp_q' if you want to quantize the input heterogeneously.
# Use only layers provided by HGQ. You can use functional API as well.
# Please refer to the list below in this document for the full list of supported layers.
model = Sequential([
HQuantize(bops_reg_factor=bops_reg_factor, name='inp_q', input_shape=(16,)),
HDense(10, activation='relu', bops_reg_factor=bops_reg_factor),
HDense(10, bops_reg_factor=bops_reg_factor),
])

...

callbacks.append(ResetMinMax()) # Reset min/max every epoch, or the estimated MBOPs could be very inaccurate.

model.fit(..., callbacks=callbacks)

...

# Compute the exact MBOPs of the model.
# This step is NOT optional, as it also records the min/max pre-activation for each layer, which is necessary for determining the number of integer bits.
compute_bops(model, X_train, bsz=bsz)

# Convert the model to hls4ml. Only the Vivado backend has been tested so far. Heterogeneous activation will NOT work with other backends. Weight heterogeneity MAY work.
model_hls = convert_from_hgq_model(
model,
'hls4ml_prj',
part='xcvu9p-flga2104-2L-e',
clock_period=5,
bias_accum=None
)

# ... (standard hls4ml workflow)

```

For a complete example, please refer to this [notebook](https://github.com/calad0i/HGQ-demos/blob/master/minimal/usage_example.ipynb). Also check out the [demo repo](https://github.com/calad0i/HGQ-demos/) for more use cases.

## Configure the HG Quantizer

```python
from HGQ import set_default_kernel_quantizer_config, set_default_pre_activation_quantizer_config
from HGQ import get_default_kernel_quantizer_config, get_default_pre_activation_quantizer_config

# The default quantizers for the pre-activation and kernel are the following:

DEFAULT_KERNEL_QUANTIZER_CONFIG = \
dict(
# initial bitwidth for the floating part
init_bw=2,
# Which dimensions to quantize homogeneously. Accept a tuple of integers, or any of ['all', 'batch', 'none', 'except_last', 'except_1st'].
skip_dims=None,
# How rounding is performed in training. Can choose from ['floor', 'standard_round', 'stochastic_round', 'fast_uniform_noise_injection', 'auto'].
# During evaluation (testing), 'standard_round' is used for every strategy except 'floor'.
# 'auto': 'floor' for layers without bias (except HActivation layers), 'standard_round' otherwise.
rnd_strategy='standard_round',
# Whether to round the bitwidth to an integer before applying the quantization rounding. Defaults to True for weights and False for pre-activations.
exact_q_value=True,
dtype=None,
# The bitwidth range for the floating part.
bw_clip=(-23, 23),
# Whether the bitwidth is trainable.
trainable=True,
# Regularization factor on the numerical bitwidth values. Useful for preventing the bitwidth from growing too large for activations that are not involved in multiplication ops (e.g. the final layer, a layer before HActivation, etc.).
regularizer=L1(1e-6),
)


DEFAULT_PRE_ACTIVATION_QUANTIZER_CONFIG = \
dict(init_bw=2,
skip_dims=(0,),  # Same as 'batch': skip the batch dimension, which should always be quantized homogeneously.
rnd_strategy='standard_round',
exact_q_value=False,
dtype=None,
bw_clip=(-23, 23),
trainable=True,
regularizer=L1(1e-6),
minmax_record=True
)
```

You can set the default quantizer config for the kernel and pre-activation quantizers by calling `set_default_kernel_quantizer_config` and `set_default_pre_activation_quantizer_config`. You can also get the default quantizer config by calling `get_default_kernel_quantizer_config` and `get_default_pre_activation_quantizer_config`.

To change the quantizer config for a specific layer, pass the config dict to the layer via the `kernel_quantizer_config` or `pre_activation_quantizer_config` keyword argument.
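
For example, a sketch of both approaches (assuming the getter returns a plain `dict` that can be modified and passed back, matching the layout shown above):

```python
from HGQ import HDense
from HGQ import get_default_kernel_quantizer_config, set_default_kernel_quantizer_config

# Adjust the global default used by all subsequently created kernel quantizers.
kq_conf = get_default_kernel_quantizer_config()
kq_conf['init_bw'] = 4          # start the kernel bitwidths higher
kq_conf['bw_clip'] = (-12, 12)  # restrict the allowed bitwidth range
set_default_kernel_quantizer_config(kq_conf)

# Or override the config for a single layer only, via the keyword argument.
layer = HDense(10, bops_reg_factor=1e-5, kernel_quantizer_config=kq_conf)
```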

## Supported Layers

### HG Layers

Layers that (can) do HG quantization on the (pre-)activation values:

`HQuantize`: Quantize the input to the next layer. When used just after the input layer, add the `inp_q` keyword to the name of the layer. The user must use this layer or `Signature` layer directly after the input layer.

`HDense`: Dense layer with HGQ.

`HConv1D/2D`: Convolutional layers with HGQ.

`HActivation`: Similar to the `Activation` layer, but with (heterogeneous) activation quantization.

`HAdd`: Element-wise addition with HGQ.

`HBaseLayer`: Base layer for HGQ layers. Do not use this layer directly. Child layers with one input should overload the `forward` and `compute_exact_bops` methods in most cases.
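
As an illustrative sketch only, the HG layers can also be combined through the Keras functional API. Here it is assumed that `HAdd` is importable from the top-level `HGQ` package and accepts the same `bops_reg_factor` keyword as the other layers:

```python
from tensorflow.keras import Model
from tensorflow.keras.layers import Input

from HGQ import HAdd, HDense, HQuantize  # HAdd as a top-level export is an assumption

bops_reg_factor = 1e-5

inp = Input((16,))
x = HQuantize(bops_reg_factor=bops_reg_factor, name='inp_q')(inp)
a = HDense(16, activation='relu', bops_reg_factor=bops_reg_factor)(x)
b = HDense(16, activation='relu', bops_reg_factor=bops_reg_factor)(x)
s = HAdd(bops_reg_factor=bops_reg_factor)([a, b])  # element-wise sum of the two branches
out = HDense(1, bops_reg_factor=bops_reg_factor)(s)

model = Model(inp, out)
```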

### Passive Layers

Layers that do not perform HG quantization themselves, but pass information necessary for HGQ on to the next layer:

`PXXXPoolND`: Pooling layers.

`PFlatten/PReshape`: Flatten/Reshape layers.

`PConcatenate`: Concatenate layers.

`PLayerBase`: Base layer for passive layers. Do not use this layer directly.

### Signature Layer

`Signature`: A special layer that does not modify the input, but passes it on to the next layer. It is used to indicate that the incoming data is already quantized to a specific bitwidth. Either this layer or the `HQuantize` layer must be used directly after the input layer.
