Skip to content

Commit

Permalink
Fix (ptq/evaluate): add support for GPFA2Q for evaluate and benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Dec 21, 2023
1 parent adab5f6 commit 6c69a9c
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def unique(sequence):
'scale_factor_type': ['float_scale'], # Scale factor type
'weight_mantissa_bit_width': [4],
'weight_exponent_bit_width': [3],
'weight_narrow_range': [False],
'layerwise_first_last_bit_width': [8], # Input and weights bit width for first and last layer
'act_mantissa_bit_width': [4],
'act_exponent_bit_width': [3],
'weight_bit_width': [8], # Weight Bit Width
Expand All @@ -95,10 +97,12 @@ def unique(sequence):
'graph_eq_merge_bias': [True], # Merge bias for Graph Equalization
'act_equalization': ['layerwise'], # Perform Activation Equalization (Smoothquant)
'learned_round': [False], # Enable/Disable Learned Round
'gptq': [True], # Enable/Disable GPTQ
'gptq': [False], # Enable/Disable GPTQ
'gpfq': [False], # Enable/Disable GPFQ
'gpfq_p': [0.75], # GPFQ P
'gptq_act_order': [False], # Use act_order euristics for GPTQ
'gpfa2q': [False], # Enable/Disable GPFA2Q
'gpfq_p': [1.0], # GPFQ P
'gpxq_act_order': [False], # Use act_order euristics for GPxQ
'accumulator_bit_width': [16], # Accumulator bit width, only in combination with GPFA2Q
'act_quant_percentile': [99.999], # Activation Quantization Percentile
'uint_sym_act_for_unsigned_values': [True], # Whether to use unsigned act quant when possible
}
Expand Down Expand Up @@ -221,6 +225,8 @@ def ptq_torchvision_models(args):
quant_format=config_namespace.quant_format,
backend=config_namespace.target_backend,
act_bit_width=config_namespace.act_bit_width,
layerwise_first_last_bit_width=config_namespace.layerwise_first_last_bit_width,
weight_narrow_range=config_namespace.weight_narrow_range,
weight_mantissa_bit_width=config_namespace.weight_mantissa_bit_width,
weight_exponent_bit_width=config_namespace.weight_exponent_bit_width,
act_mantissa_bit_width=config_namespace.act_mantissa_bit_width,
Expand All @@ -247,11 +253,25 @@ def ptq_torchvision_models(args):

if config_namespace.gpfq:
print("Performing GPFQ:")
apply_gpfq(calib_loader, quant_model, p=config_namespace.gpfq_p)
apply_gpfq(
calib_loader,
quant_model,
p=config_namespace.gpfq_p,
act_order=config_namespace.gpxq_act_order)

if config_namespace.gpfa2q:
print("Performing GPFA2Q:")
apply_gpfq(
calib_loader,
quant_model,
p=config_namespace.gpfq_p,
act_order=config_namespace.gpxq_act_order,
gpfa2q=config_namespace.gpfa2q,
accumulator_bit_width=config_namespace.accumulator_bit_width)

if config_namespace.gptq:
print("Performing gptq")
apply_gptq(calib_loader, quant_model, config_namespace.gptq_act_order)
apply_gptq(calib_loader, quant_model, config_namespace.gpxq_act_order)

if config_namespace.learned_round:
print("Applying Learned Round:")
Expand Down Expand Up @@ -309,8 +329,10 @@ def validate_config(config_namespace):
if (config_namespace.target_backend == 'fx' or config_namespace.target_backend
== 'layerwise') and config_namespace.bias_bit_width == 16:
is_valid = False
# If GPTQ is disabled, we do not care about the act_order heuristic
if not config_namespace.gptq and config_namespace.gptq_act_order:
# Only one of GPTQ, GPFQ, or GPA2Q can be enabled, or none
multiple_gpxqs = float(config_namespace.gpfq) + float(config_namespace.gptq) + float(
config_namespace.gpfa2q)
if multiple_gpxqs > 1:
is_valid = False

if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx':
Expand All @@ -320,9 +342,12 @@ def validate_config(config_namespace):

if config_namespace.act_param_method == 'mse':
config_namespace.act_quant_percentile = None

if not config_namespace.gpfq:
# gpfq_p is needed for GPFQ and GPFA2Q
if not config_namespace.gpfq and not config_namespace.gpfa2q:
config_namespace.gpfq_p = None
# accumulator bit width is not needed when not GPFA2Q
if not config_namespace.gpfa2q:
config_namespace.accumulator_bit_width = None

if config_namespace.quant_format == 'int':
config_namespace.weight_mantissa_bit_width = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ python ptq_benchmark_torchvision.py $1 --calibration-dir /scratch/datasets/image
--act_equalization layerwise \
--learned_round False \
--gptq False \
--gptq_act_order False \
--gpxq_act_order False \
--gpfq False \
--gpfq_p None \
--gpfa2q False \
--accumulator_bit_width None \
--uint_sym_act_for_unsigned_values False \
--act_quant_percentile None \
Original file line number Diff line number Diff line change
Expand Up @@ -472,12 +472,17 @@ def apply_gptq(calib_loader, model, act_order=False):
gptq.update()


def apply_gpfq(calib_loader, model, act_order, p=0.25):
def apply_gpfq(calib_loader, model, act_order, p=1.0, use_gpfa2q=False, accumulator_bit_width=None):
model.eval()
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
with torch.no_grad():
with gpfq_mode(model, p=p, use_quant_activations=True, act_order=act_order) as gpfq:
with gpfq_mode(model,
p=p,
use_quant_activations=True,
act_order=act_order,
use_gpfa2q=use_gpfa2q,
accumulator_bit_width=accumulator_bit_width) as gpfq:
gpfq_model = gpfq.model
for i in tqdm(range(gpfq.num_layers)):
for i, (images, target) in enumerate(calib_loader):
Expand Down
40 changes: 27 additions & 13 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,9 @@
add_bool_arg(
parser,
'weight-narrow-range',
default=True,
help='Narrow range for weight quantization (default: enabled)')
parser.add_argument(
'--gpfq-p', default=1.0, type=float, help='P parameter for GPFQ (default: 0.25)')
default=False,
help='Narrow range for weight quantization (default: disabled)')
parser.add_argument('--gpfq-p', default=1.0, type=float, help='P parameter for GPFQ (default: 1.0)')
parser.add_argument(
'--quant-format',
default='int',
Expand Down Expand Up @@ -211,12 +210,16 @@
default=3,
type=int,
help='Exponent bit width used with float quantization for activations (default: 3)')
parser.add_argument(
'--accumulator-bit-width',
default=None,
type=int,
help='Accumulator Bit Width for GPFA2Q (default: None)')
add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)')
add_bool_arg(
parser, 'gptq-act-order', default=False, help='GPTQ Act order heuristic (default: disabled)')
add_bool_arg(
parser, 'gpfq-act-order', default=False, help='GPFQ Act order heuristic (default: disabled)')
parser, 'gpxq-act-order', default=False, help='GPxQ Act order heuristic (default: disabled)')
add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)')
add_bool_arg(parser, 'calibrate-bn', default=False, help='Calibrate BN (default: disabled)')

Expand Down Expand Up @@ -246,8 +249,8 @@ def main():
f"w{args.weight_bit_width}_"
f"{'gptq_' if args.gptq else ''}"
f"{'gpfq_' if args.gpfq else ''}"
f"{'gptq_act_order_' if args.gptq_act_order else ''}"
f"{'gpfq_act_order_' if args.gpfq_act_order else ''}"
f"{'gpfa2q_' if args.gpfa2q else ''}"
f"{'gpxq_act_order_' if args.gpxq_act_order else ''}"
f"{'learned_round_' if args.learned_round else ''}"
f"{'weight_narrow_range_' if args.weight_narrow_range else ''}"
f"{args.bias_bit_width}bias_"
Expand All @@ -268,9 +271,10 @@ def main():
f"Weight bit width: {args.weight_bit_width} - "
f"GPTQ: {args.gptq} - "
f"GPFQ: {args.gpfq} - "
f"GPFA2Q: {args.gpfa2q} - "
f"GPFQ P: {args.gpfq_p} - "
f"GPTQ Act Order: {args.gptq_act_order} - "
f"GPFQ Act Order: {args.gpfq_act_order} - "
f"GPxQ Act Order: {args.gpxq_act_order} - "
f"GPFA2Q Accumulator Bit Width: {args.accumulator_bit_width} - "
f"Learned Round: {args.learned_round} - "
f"Weight narrow range: {args.weight_narrow_range} - "
f"Bias bit width: {args.bias_bit_width} - "
Expand Down Expand Up @@ -367,11 +371,21 @@ def main():

if args.gpfq:
print("Performing GPFQ:")
apply_gpfq(calib_loader, quant_model, p=args.gpfq_p, act_order=args.gpfq_act_order)
apply_gpfq(calib_loader, quant_model, p=args.gpfq_p, act_order=args.gpxq_act_order)

if args.gpfa2q:
print("Performing GPFA2Q:")
apply_gpfq(
calib_loader,
quant_model,
p=args.gpfq_p,
act_order=args.gpxq_act_order,
use_gpfa2q=args.gpfa2q,
accumulator_bit_width=args.accumulator_bit_width)

if args.gptq:
print("Performing GPTQ:")
apply_gptq(calib_loader, quant_model, act_order=args.gptq_act_order)
apply_gptq(calib_loader, quant_model, act_order=args.gpxq_act_order)

if args.learned_round:
print("Applying Learned Round:")
Expand Down

0 comments on commit 6c69a9c

Please sign in to comment.