diff --git a/docs/source/openvino/export.mdx b/docs/source/openvino/export.mdx index 1d0c534193..e25d50fa0c 100644 --- a/docs/source/openvino/export.mdx +++ b/docs/source/openvino/export.mdx @@ -31,7 +31,7 @@ Check out the help for more options: ```text usage: optimum-cli export openvino [-h] -m MODEL [--task TASK] [--framework {pt,tf}] [--trust-remote-code] - [--weight-format {fp32,fp16,int8,int4,mxfp4,nf4}] [--quant-mode {int8,f8e4m3,f8e5m2}] + [--weight-format {fp32,fp16,int8,int4,mxfp4,nf4}] [--quant-mode {int8,f8e4m3,f8e5m2,nf4_f8e4m3}] [--library {transformers,diffusers,timm,sentence_transformers,open_clip}] [--cache_dir CACHE_DIR] [--pad-token-id PAD_TOKEN_ID] [--ratio RATIO] [--sym] [--group-size GROUP_SIZE] [--backup-precision {none,int8_sym,int8_asym}] @@ -67,7 +67,7 @@ Optional arguments: on your local machine arbitrary code present in the model repository. --weight-format {fp32,fp16,int8,int4,mxfp4,nf4} The weight format of the exported model. - --quant-mode {int8,f8e4m3,f8e5m2} + --quant-mode {int8,f8e4m3,f8e5m2,nf4_f8e4m3} Quantization precision mode. This is used for applying full model quantization including activations. --library {transformers,diffusers,timm,sentence_transformers,open_clip} diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index 8d272a693f..75b218677d 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -78,7 +78,7 @@ def parse_args_openvino(parser: "ArgumentParser"): optional_group.add_argument( "--quant-mode", type=str, - choices=["int8", "f8e4m3", "f8e5m2"], + choices=["int8", "f8e4m3", "f8e5m2", "nf4_f8e4m3"], default=None, help=( "Quantization precision mode. This is used for applying full model quantization including activations. " @@ -307,7 +307,14 @@ def parse_args(parser: "ArgumentParser"): def run(self): from ...exporters.openvino.__main__ import infer_task, main_export, maybe_convert_tokenizers from ...exporters.openvino.utils import save_preprocessors - from ...intel.openvino.configuration import _DEFAULT_4BIT_CONFIG, OVConfig, get_default_int4_config + from ...intel.openvino.configuration import ( + _DEFAULT_4BIT_CONFIG, + OVCompressWeightsOptions, + OVConfig, + OVGeneralQuantizationConfig, + OVQuantizeOptions, + get_default_int4_config, + ) if self.args.library is None: # TODO: add revision, subfolder and token to args @@ -342,43 +349,39 @@ def run(self): if no_compression_parameter_provided(self.args) and self.args.weight_format == "int4": quantization_config = get_default_int4_config(self.args.model) else: - is_int8 = self.args.weight_format == "int8" - quantization_config = { - "bits": 8 if is_int8 else 4, - "ratio": 1 if is_int8 else (self.args.ratio or _DEFAULT_4BIT_CONFIG["ratio"]), - "sym": self.args.sym or False, - "group_size": -1 if is_int8 else self.args.group_size, - "all_layers": None if is_int8 else self.args.all_layers, - "dataset": self.args.dataset, - "num_samples": self.args.num_samples, - "quant_method": "awq" if self.args.awq else "default", - "sensitivity_metric": self.args.sensitivity_metric, - "scale_estimation": self.args.scale_estimation, - "gptq": self.args.gptq, - "lora_correction": self.args.lora_correction, - "weight_format": self.args.weight_format, - "backup_precision": self.args.backup_precision, - } + quantization_config = prepare_for_wc_config(self.args, _DEFAULT_4BIT_CONFIG) if quantization_config.get("dataset", None) is not None: quantization_config["trust_remote_code"] = self.args.trust_remote_code ov_config = 
OVConfig(quantization_config=quantization_config) - else: + elif self.args.quant_mode is not None: if self.args.dataset is None: raise ValueError( "Dataset is required for full quantization. Please provide it with --dataset argument." ) - quantization_config = { - "weight_format": self.args.quant_mode, - "activation_format": self.args.quant_mode, - "bits": 8, - "sym": self.args.sym or False, - "dataset": self.args.dataset, - "num_samples": self.args.num_samples, - "smooth_quant_alpha": self.args.smooth_quant_alpha, - "trust_remote_code": self.args.trust_remote_code, - } + if self.args.quant_mode == "nf4_f8e4m3": + wc_config = prepare_for_wc_config(self.args, _DEFAULT_4BIT_CONFIG) + wc_config["weight_format"] = "nf4" + cw_options = OVCompressWeightsOptions.init_with_format(**wc_config) + + q_config = prepare_for_q_config(self.args) + q_config["activation_format"] = "f8e4m3" + q_options = OVQuantizeOptions.init_with_format(**q_config) + + quantization_config = OVGeneralQuantizationConfig.init_with_format( + bits=8, + sym=self.args.sym, + ignored_scope=None, + num_samples=self.args.num_samples, + dataset=self.args.dataset, + trust_remote_code=self.args.trust_remote_code, + weight_format=self.args.weight_format, + ) + quantization_config.compress_weights_options = cw_options + quantization_config.quantize_options = q_options + else: + quantization_config = prepare_for_q_config(self.args) ov_config = OVConfig(quantization_config=quantization_config) quantization_config = ov_config.quantization_config if ov_config else None @@ -470,3 +473,36 @@ def run(self): library_name=library_name, # **input_shapes, ) + + +def prepare_for_wc_config(args, default_configs): + is_int8 = args.weight_format == "int8" + return { + "bits": 8 if is_int8 else 4, + "ratio": 1 if is_int8 else (args.ratio or default_configs["ratio"]), + "sym": args.sym or False, + "group_size": -1 if is_int8 else args.group_size, + "all_layers": None if is_int8 else args.all_layers, + "dataset": args.dataset, + "num_samples": args.num_samples, + "quant_method": "awq" if args.awq else "default", + "sensitivity_metric": args.sensitivity_metric, + "scale_estimation": args.scale_estimation, + "gptq": args.gptq, + "lora_correction": args.lora_correction, + "weight_format": args.weight_format, + "backup_precision": args.backup_precision, + } + + +def prepare_for_q_config(args): + return { + "weight_format": args.quant_mode, + "activation_format": args.quant_mode, + "bits": 8, + "sym": args.sym or False, + "dataset": args.dataset, + "num_samples": args.num_samples, + "smooth_quant_alpha": args.smooth_quant_alpha, + "trust_remote_code": args.trust_remote_code, + } diff --git a/optimum/intel/openvino/configuration.py b/optimum/intel/openvino/configuration.py index 59b4b65ddd..ba0a895cbd 100644 --- a/optimum/intel/openvino/configuration.py +++ b/optimum/intel/openvino/configuration.py @@ -775,3 +775,348 @@ def to_dict(self) -> Dict[str, Any]: def to_diff_dict(self) -> Dict[str, Any]: return self._to_dict_safe(to_diff_dict=True) + + +class OVCompressWeightsOptions: + def __init__( + self, + mode: str, + ratio: Optional[float] = None, + group_size: Optional[int] = None, + all_layers: Optional[bool] = None, + sensitivity_metric: Optional[str] = None, + awq: Optional[bool] = None, + scale_estimation: Optional[bool] = None, + gptq: Optional[bool] = None, + lora_correction: Optional[bool] = None, + backup_mode: Optional[str] = None, + advanced_parameters: Optional[Dict] = None, + ): + """ + Class containing specific nncf.compress_weights method's 
options.
+        Args:
+            mode (`str`):
+                Mode for weight compression. Possible values: ['int4_sym', 'int4_asym', 'int8_sym', 'int8_asym', 'e2m1', 'nf4'].
+            ratio (`float`, *optional*):
+                The ratio between baseline and backup precisions.
+            group_size (`int`, *optional*):
+                The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
+            all_layers (`bool`, *optional*):
+                Defines how many layers are compressed to 4-bits while the rest are kept in 8-bit precision.
+            sensitivity_metric (`str`, *optional*):
+                The sensitivity metric for assigning quantization precision to layers. In order to
+                preserve the accuracy of the model, the more sensitive layers receive a higher precision.
+            awq (`bool`, *optional*):
+                Indicates whether to apply an AWQ algorithm. AWQ improves generation quality of INT4-compressed LLMs, but requires
+                additional time for tuning weights on a calibration dataset. To run AWQ, providing a dataset is
+                required.
+            scale_estimation (`bool`, *optional*):
+                Indicates whether to apply a scale estimation algorithm that minimizes the L2 error between the original and
+                compressed layers. Providing a dataset is required to run scale estimation.
+            gptq (`bool`, *optional*):
+                Whether to apply the GPTQ algorithm. GPTQ optimizes compressed weights in a layer-wise fashion to minimize the
+                difference between activations of a compressed and original layer. Dataset is required to run GPTQ.
+            lora_correction (`bool`, *optional*):
+                If True, apply LoRA Correction algorithm. When enabled, this algorithm introduces low-rank adaptation
+                layers in the model that can recover accuracy after weight compression at some cost of inference latency.
+                It calculates low-rank matrices via singular value decomposition (SVD) on the difference between the
+                original and quantized weights. These matrices are iteratively refined by solving a system of linear
+                equations to improve accuracy.
+            backup_mode (`str`, *optional*):
+                Defines a backup precision for mixed-precision weight compression.
+                - "none" stands for original floating-point precision of the model weights, in this case weights are
+                    retained in their original precision without any quantization.
+                - "int8_sym" stands for 8-bit integer symmetric quantization without zero point.
+                - "int8_asym" stands for 8-bit integer asymmetric quantization with zero points per each quantization group.
+            advanced_parameters (`Dict`, *optional*):
+                Defines a dictionary with the advanced parameters.
+        """
+        self.mode = mode
+        self.ratio = ratio
+        self.group_size = group_size
+        self.all_layers = all_layers
+        self.sensitivity_metric = sensitivity_metric
+        self.awq = awq
+        self.scale_estimation = scale_estimation
+        self.gptq = gptq
+        self.lora_correction = lora_correction
+        self.backup_mode = backup_mode
+        self.advanced_parameters = advanced_parameters
+
+        self._nncf_dict = None
+
+    @staticmethod
+    def init_with_format(
+        bits: int = 8,
+        sym: bool = False,
+        group_size: Optional[int] = None,
+        ratio: float = 1.0,
+        all_layers: Optional[bool] = None,
+        sensitivity_metric: Optional[str] = None,
+        quant_method: Union[str, OVQuantizationMethod] = OVQuantizationMethod.DEFAULT,
+        scale_estimation: bool = None,
+        weight_format: Optional[str] = None,
+        gptq: bool = None,
+        lora_correction: bool = None,
+        backup_precision: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Method for the backwards-compatible OVCompressWeightsOptions initialization.
+        All options are the same as those in the OVWeightQuantizationConfig.
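+        For example (illustrative of the mapping below): `bits=4, sym=True` with no explicit
+        `weight_format` resolves to mode `"int4_sym"`, while an explicit `weight_format="nf4"`
+        is passed through as mode `"nf4"` unchanged.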
+        """
+        signed_bitness = {
+            4: "int4",
+            8: "int8",
+        }
+        mode = weight_format if weight_format else signed_bitness[bits]
+        if mode in signed_bitness.values():
+            mode += "_sym" if sym else "_asym"
+
+        if isinstance(quant_method, str):
+            awq = quant_method == "awq"
+        elif isinstance(quant_method, OVQuantizationMethod):
+            awq = quant_method == OVQuantizationMethod.AWQ
+
+        return OVCompressWeightsOptions(
+            mode=mode,
+            ratio=ratio,
+            group_size=group_size,
+            all_layers=all_layers,
+            sensitivity_metric=sensitivity_metric,
+            awq=awq,
+            scale_estimation=scale_estimation,
+            gptq=gptq,
+            backup_mode=backup_precision,
+            lora_correction=lora_correction,
+        )
+
+    def to_nncf_dict(self) -> Dict[str, Any]:
+        """
+        Returns a dictionary with the NNCF-friendly variables that are ready to use.
+        """
+        if self._nncf_dict:
+            return self._nncf_dict
+
+        if is_nncf_available():
+            mode = nncf.CompressWeightsMode(self.mode)
+            sensitivity_metric = nncf.SensitivityMetric(self.sensitivity_metric) if self.sensitivity_metric else None
+            backup_mode = nncf.BackupMode(self.backup_mode) if self.backup_mode else None
+            self._nncf_dict = {
+                "mode": mode,
+                "ratio": self.ratio,
+                "group_size": self.group_size,
+                "all_layers": self.all_layers,
+                "sensitivity_metric": sensitivity_metric,
+                "awq": self.awq,
+                "scale_estimation": self.scale_estimation,
+                "gptq": self.gptq,
+                "lora_correction": self.lora_correction,
+                "backup_mode": backup_mode,
+            }
+            return self._nncf_dict
+
+        raise ImportError("NNCF is required to execute this method. Please install nncf first.")
+
+    def to_dict(self) -> Dict[str, Any]:
+        return copy.deepcopy(self.__dict__)
+
+
+class OVQuantizeOptions:
+    def __init__(
+        self,
+        mode: Optional[str] = None,
+        preset: Optional[str] = None,
+        target_device: str = "any",
+        fast_bias_correction: bool = True,
+        model_type: Optional[str] = None,
+        advanced_parameters: Optional[Dict] = None,
+    ):
+        """
+        Class containing specific nncf.quantize method's options.
+        Args:
+            mode (`str`, *optional*):
+                Defines special quantization modes. Possible values: ['fp8_e4m3', 'fp8_e5m2'].
+            preset (`str`, *optional*):
+                Quantization preset, usually selecting either a symmetric or an asymmetric scheme. Possible values: ['performance', 'mixed'].
+            target_device (`str`, defaults to "any"):
+                Target device architecture for compression. Possible values: ['any', 'cpu', 'gpu', 'npu', 'cpu_spr'].
+            fast_bias_correction (`bool`, defaults to True):
+                Whether to apply fast or full bias correction algorithm.
+            model_type (`str`, *optional*):
+                Model type used to specify additional patterns in the model. Only `transformer` is supported for now.
+            advanced_parameters (`Dict`, *optional*):
+                Defines a dictionary with the advanced parameters.
+                Examples of the values:
+                - overflow_fix (`str`):
+                    Parameter for controlling overflow fix setting.
+                - smooth_quant_alphas (`dict`):
+                    SmoothQuant alpha parameter that improves the distribution of activations before MatMul layers and
+                    reduces quantization error.
+ Examples of the values: + - matmul (`float`) + - convolution (`float`) + """ + self.mode = mode + self.preset = preset + self.target_device = target_device + self.fast_bias_correction = fast_bias_correction + self.model_type = model_type + self.advanced_parameters = advanced_parameters + + self._nncf_dict = None + + @staticmethod + def init_with_format( + bits: int = 8, + sym: bool = False, + ignored_scope: Optional[dict] = None, + num_samples: Optional[int] = 300, + model_type: str = "transformer", + fast_bias_correction: bool = True, + overflow_fix: str = "disable", + dataset: Optional[str] = None, + tokenizer: Optional[str] = None, + processor: Optional[str] = None, + trust_remote_code: bool = False, + smooth_quant_alpha: Optional[float] = None, + weight_format: Optional[str] = "int8", + activation_format: Optional[str] = "int8", + **kwargs, + ): + """ + Method for the backwards-compatible OVQuantizeOptions initialization. + All options are the same as those in the OVQuantizationConfig. + """ + preset = "performance" if sym else "mixed" + advanced_parameters = {"overflow_fix": overflow_fix} + if smooth_quant_alpha: + advanced_parameters["smooth_quant_alphas"] = {"matmul": smooth_quant_alpha} + + mode = None + if activation_format: + mode_map = { + "f8e4m3": "fp8_e4m3", + "f8e5m2": "fp8_e5m2", + } + mode = mode_map[activation_format] + preset = "performance" + + return OVQuantizeOptions( + mode=mode, + preset=preset, + target_device="any", + fast_bias_correction=fast_bias_correction, + model_type=model_type, + advanced_parameters=advanced_parameters, + ) + + def to_nncf_dict(self) -> Dict[str, Any]: + """ + Returns a dictionary with the NNCF-friendly variables that are ready to use. + """ + if self._nncf_dict: + return self._nncf_dict + + if is_nncf_available(): + mode = nncf.QuantizationMode(self.mode) if self.mode else None + preset = nncf.QuantizationPreset(self.preset) + target_device = nncf.TargetDevice(self.target_device.upper()) + model_type = nncf.ModelType(self.model_type) if self.model_type else None + advanced_parameters = None + if self.advanced_parameters: + advanced_parameters = nncf.AdvancedQuantizationParameters( + overflow_fix=self.advanced_parameters["overflow_fix"], + ) + if "smooth_quant_alphas" in self.advanced_parameters: + advanced_parameters.smooth_quant_alphas = nncf.AdvancedSmoothQuantParameters( + **self.advanced_parameters["smooth_quant_alphas"] + ) + + self._nncf_dict = { + "mode": mode, + "preset": preset, + "target_device": target_device, + "fast_bias_correction": self.fast_bias_correction, + "model_type": model_type, + "advanced_parameters": advanced_parameters, + } + return self._nncf_dict + + raise ImportError("NNCF is required to execute this method. Please install nncf first.") + + def to_dict(self) -> Dict: + return copy.deepcopy(self.__dict__) + + +class OVGeneralQuantizationConfig(QuantizationConfigMixin): + def __init__( + self, + ignored_scope: Optional[Dict] = None, + num_samples: Optional[int] = None, + compress_weights_options: Optional[OVCompressWeightsOptions] = None, + quantize_options: Optional[OVQuantizeOptions] = None, + ): + """ + Class containing general options for the NNCF-based quantization. + Args: + ignored_scope (`dict`, *optional*): + An ignored scope that defines the list of model nodes to be ignored during quantization. Dictionary + entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class. 
+ num_samples (`int`, *optional*): + The maximum number of samples composing the calibration dataset. + compress_weights_options (`OVCompressWeightsOptions`, *optional*): + See OVCompressWeightsOptions instance. + quantize_options (`OVQuantizeOptions`, *optional*): + See OVQuantizeOptions instance. + """ + self.ignored_scope = ignored_scope + self.num_samples = num_samples + self.compress_weights_options = compress_weights_options + self.quantize_options = quantize_options + self.bits = None + self.sym = None + self.dataset = None + self.tokenizer = None + self.processor = None + self.trust_remote_code = None + self.weight_format = None + + @staticmethod + def init_with_format( + bits: int = 8, + sym: bool = False, + ignored_scope: Optional[dict] = None, + num_samples: Optional[int] = None, + dataset: Optional[Optional[Union[str, List[str]]]] = None, + tokenizer: Optional[str] = None, + processor: Optional[str] = None, + trust_remote_code: bool = False, + weight_format: Optional[str] = None, + ): + """ + Method for the backwards-compatible QuantizationConfigMixin initialization. + All options are the same as those in the QuantizationConfigMixin. + """ + config = OVGeneralQuantizationConfig( + ignored_scope=ignored_scope, + num_samples=num_samples, + ) + config.bits = bits + config.sym = sym + config.dataset = dataset + config.tokenizer = tokenizer + config.processor = processor + config.trust_remote_code = trust_remote_code + config.weight_format = weight_format + return config + + def get_ignored_scope_instance(self) -> "nncf.IgnoredScope": + ignored_scope = copy.deepcopy(self.ignored_scope) if self.ignored_scope else {} + return nncf.IgnoredScope(**ignored_scope) + + def to_dict(self): + result = copy.deepcopy(self.__dict__) + result["compress_weights_options"] = self.compress_weights_options.to_dict() + result["quantize_options"] = self.quantize_options.to_dict() + return result diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index f61c2b93ca..391643a0fe 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -60,6 +60,7 @@ from ..utils.modeling_utils import get_model_device from .configuration import ( OVConfig, + OVGeneralQuantizationConfig, OVQuantizationConfig, OVQuantizationConfigBase, OVQuantizationMethod, @@ -451,7 +452,7 @@ def _quantize_ovbasemodel( else: _weight_only_quantization(self.model.model, quantization_config, calibration_dataset, **kwargs) self.model.request = None - else: + elif isinstance(quantization_config, OVQuantizationConfig): if not isinstance(quantization_config, OVQuantizationConfig): raise ValueError(f"Unsupported type of quantization config: {type(quantization_config)}") @@ -467,6 +468,15 @@ def _quantize_ovbasemodel( ) self.model.model = quantized_model self.model.request = None + else: + if calibration_dataset is None: + raise ValueError("Calibration dataset is required to run quantization.") + + quantized_model = _general_quantization( + self.model.model, quantization_config, calibration_dataset, **kwargs + ) + self.model.model = quantized_model + self.model.request = None if save_directory is not None: self.model.save_pretrained(save_directory) @@ -1187,3 +1197,54 @@ def _hybrid_quantization( **kwargs, ) return quantized_model + + +def _general_quantization( + model: openvino.Model, + quantization_config: OVGeneralQuantizationConfig, + calibration_dataset: nncf.Dataset, + **kwargs, +) -> openvino.Model: + """ + Quantize a model with NNCF in two possible steps: + - 
weights-only quantization with nncf.compress_weights method. + - full quantization (excluding weights from previous step) with nncf.quantize method. + + Args: + model (`openvino.runtime.Model`): + The OpenVINO Runtime model for applying quantization. + quantization_config (`OVGeneralQuantizationConfig`): + The configuration containing the parameters related to quantization. + calibration_dataset (`nncf.Dataset`): + The dataset used for quantization. + Returns: + The OpenVINO Runtime model with applied quantization. + """ + quantized_model = model + + ignored_scope = quantization_config.get_ignored_scope_instance() + + if quantization_config.compress_weights_options: + ops_with_weights = _collect_ops_with_weights(model) + wc_kwargs = copy.deepcopy(kwargs) + wc_kwargs.update(quantization_config.compress_weights_options.to_nncf_dict()) + quantized_model = nncf.compress_weights( + model, + ignored_scope=ignored_scope, + dataset=calibration_dataset, + subset_size=quantization_config.num_samples, + **wc_kwargs, + ) + ignored_scope.names += ops_with_weights + + if quantization_config.quantize_options: + q_kwargs = copy.deepcopy(kwargs) + q_kwargs.update(quantization_config.quantize_options.to_nncf_dict()) + quantized_model = nncf.quantize( + model, + calibration_dataset, + subset_size=quantization_config.num_samples, + ignored_scope=ignored_scope, + **q_kwargs, + ) + return quantized_model diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index 7c64d84d3d..d9e20df772 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -131,6 +131,14 @@ class OVCLIExportTestCase(unittest.TestCase): (13,), (16,), ), + ( + "text-generation", + "llama", + "nf4_f8e4m3", + "--dataset wikitext2 --num-samples 1 --smooth-quant-alpha 0.9 --group-size 16 --trust-remote-code", + (4,), + (14,), + ), ] TEST_4BIT_CONFIGURATIONS = [ @@ -446,7 +454,11 @@ def test_exporters_cli_full_quantization( for i, model in enumerate(models): num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model) self.assertEqual(expected_fake_nodes[i], num_fake_nodes) - self.assertEqual(expected_low_precision_nodes[i], num_weight_nodes[quant_mode]) + weight_types = quant_mode.split("_") + num_weights = 0 + for weight_type in weight_types: + num_weights += num_weight_nodes[weight_type] + self.assertEqual(expected_low_precision_nodes[i], num_weights) def test_exporters_cli_int4_with_local_model_and_default_config(self): with TemporaryDirectory() as tmpdir:
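
Usage sketch (reviewer note, not part of the patch): the new mode is driven through the CLI flag added above, e.g. `optimum-cli export openvino -m <model_id> --quant-mode nf4_f8e4m3 --dataset wikitext2 --num-samples 64 --group-size 16 --smooth-quant-alpha 0.9 <output_dir>`, where `<model_id>` and `<output_dir>` are placeholders and the numeric values are only examples. The Python snippet below mirrors what the CLI branch constructs with the classes introduced in this PR; it is illustrative rather than a separate public API.

```python
# Illustrative sketch of the config the CLI path builds for --quant-mode nf4_f8e4m3.
# Class names follow this PR; values are examples, not recommended defaults.
from optimum.intel import OVConfig
from optimum.intel.openvino.configuration import (
    OVCompressWeightsOptions,
    OVGeneralQuantizationConfig,
    OVQuantizeOptions,
)

# Step 1 options: NF4 weight compression (consumed by nncf.compress_weights).
cw_options = OVCompressWeightsOptions.init_with_format(
    bits=4, sym=True, ratio=1.0, group_size=16, weight_format="nf4"
)

# Step 2 options: f8e4m3 quantization of the remaining ops (consumed by nncf.quantize).
q_options = OVQuantizeOptions.init_with_format(
    sym=True, activation_format="f8e4m3", smooth_quant_alpha=0.9
)

# General options shared by both steps (calibration dataset, number of samples, ...).
quantization_config = OVGeneralQuantizationConfig.init_with_format(
    bits=8, sym=True, dataset="wikitext2", num_samples=64
)
quantization_config.compress_weights_options = cw_options
quantization_config.quantize_options = q_options

# Wrapped this way, the config is routed to the new _general_quantization helper,
# which runs compress_weights first and then quantize with the compressed ops ignored.
ov_config = OVConfig(quantization_config=quantization_config)
```

This mirrors the new test case added to tests/openvino/test_exporters_cli.py, which exercises the same flag combination (wikitext2, group size 16, SmoothQuant alpha 0.9) on a tiny llama test model.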