# Validate pass config before instantiating the pass (#1553)

* Use of search points should be limited to engine logic only. The rest of
the Olive implementation should receive a validated configuration to use
(see the sketch after this list).
* Validation covers the complete configuration, not merely a search point.
* Fixed a few issues related to the use of BasePassConfig vs. FullPassConfig.
* Added local caching of loaded pass modules to OlivePackageConfig.
* Renamed a few variables in engine logic to be explicit about the use of
pass config vs. pass run config.
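
A minimal sketch of the intended flow, assuming the default package config, a hypothetical pass type name, and the `DEFAULT_CPU_ACCELERATOR` import path; none of the concrete values below are part of this change:

```python
# Sketch of the validate-then-instantiate flow this PR moves toward.
from olive.hardware.accelerator import DEFAULT_CPU_ACCELERATOR  # assumed import path
from olive.package_config import OlivePackageConfig

package_config = OlivePackageConfig.load_default_config()
pass_cls = package_config.import_pass_module("OnnxConversion")  # hypothetical pass type

user_config = {"target_opset": 17}  # hypothetical user-provided options

# New classmethod: validate the complete input config before anything is instantiated.
if pass_cls.validate_config(user_config, DEFAULT_CPU_ACCELERATOR, disable_search=True):
    # generate_config (renamed from generate_search_space) resolves defaults and
    # user values into a single, fully validated config dict.
    config = pass_cls.generate_config(DEFAULT_CPU_ACCELERATOR, user_config, disable_search=True)
    the_pass = pass_cls(DEFAULT_CPU_ACCELERATOR, config)  # host_device omitted, assumed optional
```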

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
shaahji authored Jan 21, 2025
1 parent 92f431e commit be1278c
Showing 31 changed files with 369 additions and 418 deletions.
8 changes: 6 additions & 2 deletions olive/cache.py
@@ -11,7 +11,7 @@
from copy import deepcopy
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from olive.common.config_utils import ConfigBase, convert_configs_to_dicts, validate_config
from olive.common.constants import DEFAULT_CACHE_DIR, DEFAULT_WORKFLOW_ID
@@ -274,7 +274,11 @@ def cache_olive_config(self, olive_config: Dict):
logger.exception("Failed to cache olive config")

def get_output_model_id(
self, pass_name: int, pass_config: dict, input_model_id: str, accelerator_spec: "AcceleratorSpec" = None
self,
pass_name: str,
pass_config: Dict[str, Any],
input_model_id: str,
accelerator_spec: "AcceleratorSpec" = None,
):
run_json = self.get_run_json(pass_name, pass_config, input_model_id, accelerator_spec)
return hash_dict(run_json)[:8]
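
For reference, a hedged usage sketch of the retyped method, assuming `cache` is an instance of the cache class defined in `olive/cache.py`; all argument values are hypothetical:

```python
# pass_name is now typed as str and pass_config as Dict[str, Any]; the result is
# the first 8 characters of a hash over the corresponding run JSON.
output_model_id = cache.get_output_model_id(
    pass_name="OnnxQuantization",           # hypothetical pass name
    pass_config={"quant_mode": "static"},   # hypothetical resolved pass config
    input_model_id="model_0",               # hypothetical parent model id
    accelerator_spec=None,
)
```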
231 changes: 94 additions & 137 deletions olive/engine/engine.py

Large diffs are not rendered by default.

26 changes: 16 additions & 10 deletions olive/package_config.py
@@ -5,21 +5,26 @@
import functools
import importlib
from pathlib import Path
from typing import Dict, List
from typing import TYPE_CHECKING, ClassVar, Dict, List, Type

from olive.common.config_utils import ConfigBase
from olive.common.pydantic_v1 import validator
from olive.common.pydantic_v1 import Field, validator
from olive.passes import PassModuleConfig

if TYPE_CHECKING:
from olive.passes.olive_pass import Pass


class OlivePackageConfig(ConfigBase):
"""Configuration for an Olive package.
passes key is case-insensitive and stored in lowercase.
"""

passes: Dict[str, PassModuleConfig]
extra_dependencies: Dict[str, List[str]]
passes: Dict[str, PassModuleConfig] = Field(default_factory=dict)
extra_dependencies: Dict[str, List[str]] = Field(default_factory=dict)

_pass_modules: ClassVar[Dict[str, Type["Pass"]]] = {}

@validator("passes")
def validate_passes(cls, values):
@@ -36,12 +41,13 @@ def load_default_config() -> "OlivePackageConfig":
return OlivePackageConfig.parse_file(OlivePackageConfig.get_default_config_path())

def import_pass_module(self, pass_type: str):
pass_module_config = self.get_pass_module_config(pass_type)
module_path, module_name = pass_module_config.module_path.rsplit(".", 1)
module = importlib.import_module(module_path, module_name)
cls = getattr(module, module_name)
pass_module_config.set_class_variables(cls)
return cls
if pass_type not in self._pass_modules:
pass_module_config = self.get_pass_module_config(pass_type)
module_path, module_name = pass_module_config.module_path.rsplit(".", 1)
module = importlib.import_module(module_path, module_name)
self._pass_modules[pass_type] = cls = getattr(module, module_name)
pass_module_config.set_class_variables(cls)
return self._pass_modules[pass_type]

def get_pass_module_config(self, pass_type: str) -> PassModuleConfig:
if "." in pass_type:
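
The effect of the new `_pass_modules` cache, sketched below with a hypothetical pass type name: repeated imports of the same pass type resolve to the same cached class.

```python
from olive.package_config import OlivePackageConfig

package_config = OlivePackageConfig.load_default_config()

# First call imports the module and stores the resolved class in _pass_modules.
cls_first = package_config.import_pass_module("OnnxConversion")  # hypothetical pass type
# Subsequent calls skip importlib and return the cached class.
cls_second = package_config.import_pass_module("OnnxConversion")
assert cls_first is cls_second
```

Because `_pass_modules` is a `ClassVar`, the cache is shared across `OlivePackageConfig` instances.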
87 changes: 48 additions & 39 deletions olive/passes/olive_pass.py
@@ -75,7 +75,7 @@ def __init__(
assert config is not None, "Please specify the configuration for the pass."

# NOTE: The :disable_search argument isn't impactful here since the search isn't
# dependent on it. The same parameter in :generate_search_space is what decides
# dependent on it. The same parameter in :generate_config is what decides
# how search points are handled. Here, using default values for each config
# parameter in the config class keeps it simple.
config_class, default_config = self.get_config_class(accelerator_spec, True)
@@ -87,19 +87,12 @@ def __init__(
self.config = config
self._user_module_loader = UserModuleLoader(self.config.get("user_script"), self.config.get("script_dir"))

self._fixed_params = {}
self.search_space = {}
for k, v in self.config.items():
if isinstance(v, SearchParameter):
self.search_space[k] = v
else:
self._fixed_params[k] = v

# Params that are paths [(param_name, required)]
self.path_params = []
for param, param_config in default_config.items():
if param_config.category in (ParamCategory.PATH, ParamCategory.DATA):
self.path_params.append((param, param_config.required, param_config.category))
self.path_params = [
(param, param_config.required, param_config.category)
for param, param_config in default_config.items()
if param_config.category in (ParamCategory.PATH, ParamCategory.DATA)
]

self._initialized = False

@@ -113,30 +106,32 @@ def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
return True

@classmethod
def generate_search_space(
def generate_config(
cls,
accelerator_spec: AcceleratorSpec,
config: Optional[Union[Dict[str, Any], BasePassConfig]] = None,
config: Optional[Dict[str, Any]] = None,
disable_search: Optional[bool] = False,
) -> Tuple[Type[BasePassConfig], Dict[str, Any]]:
) -> Dict[str, Any]:
"""Generate search space for the pass."""
assert accelerator_spec is not None, "Please specify the accelerator spec for the pass"
config = config or {}

# Get the config class with default value or default search value
config_class, default_config = cls.get_config_class(accelerator_spec, disable_search)

if not disable_search:
# Replace user-provided values with Categorical if user intended to search
config = cls.identify_search_values(config, default_config)
config = cls._identify_search_values(config, default_config)

# Generate the search space by using both default value and default search value and user provided config
config = validate_config(config, config_class)

config = cls._resolve_config(config, default_config)
return cls._init_fixed_and_search_params(config, default_config)
fixed_values, search_params = cls._init_fixed_and_search_params(config, default_config)
return {**fixed_values, **search_params}

@classmethod
def identify_search_values(
def _identify_search_values(
cls,
config: Dict[str, Any],
default_config: Dict[str, PassConfigParam],
@@ -189,26 +184,41 @@ def default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConf
), f"{param} ending with data_config must be of type DataConfig."
return config

def config_at_search_point(self, point: Dict[str, Any]) -> Dict[str, Any]:
@classmethod
def config_at_search_point(
cls,
point: Dict[str, Any],
accelerator_spec: AcceleratorSpec,
config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Get the configuration for the pass at a specific point in the search space."""
assert set(point.keys()) == set(self.search_space.keys()), "Search point is not in the search space."
config = self._fixed_params.copy()
config.update(**point)
return self._config_class(**config).dict()
assert accelerator_spec is not None, "Please specify the accelerator spec for the pass"

# Get the config class with default search value
config_class, default_config = cls.get_config_class(accelerator_spec)

# Replace user-provided values with Categorical if user intended to search
config = cls._identify_search_values(config or {}, default_config)

# Generate the search space by using both default value and default search value and user provided config
config = validate_config(config, config_class)
config = cls._resolve_config(config, default_config)
fixed_values, search_params = cls._init_fixed_and_search_params(config, default_config)
assert set(point.keys()) == set(search_params.keys()), "Search point is not in the search space."
return {**fixed_values, **search_params, **point}

def validate_search_point(
self, search_point: Dict[str, Any], accelerator_spec: AcceleratorSpec, with_fixed_value: bool = False
@classmethod
def validate_config(
cls,
config: Dict[str, Any],
accelerator_spec: AcceleratorSpec,
disable_search: Optional[bool] = False,
) -> bool:
"""Validate the search point for the pass."""
"""Validate the input config for the pass."""
return True

def run(
self, model: OliveModelHandler, output_model_path: str, point: Optional[Dict[str, Any]] = None
) -> OliveModelHandler:
def run(self, model: OliveModelHandler, output_model_path: str) -> OliveModelHandler:
"""Run the pass on the model at a specific point in the search space."""
point = point or {}
config = self.config_at_search_point(point)

if not self._initialized:
self._initialize()
self._initialized = True
@@ -218,7 +228,7 @@ def run(
for rank in range(model.num_ranks):
input_ranked_model = model.load_model(rank)
ranked_output_path = Path(output_model_path).with_suffix("") / model.ranked_model_name(rank)
self._run_for_config(input_ranked_model, config, str(ranked_output_path))
self._run_for_config(input_ranked_model, self.config, str(ranked_output_path))

# ranked model don't have their own model_attributes, they are just part of the distributed model
# which has the model_attributes
@@ -235,7 +245,7 @@ def run(
component_names = []
for component_name, component_model in model.get_model_components():
component_output_path = Path(output_model_path).with_suffix("") / component_name
output_model_component = self._run_for_config(component_model, config, str(component_output_path))
output_model_component = self._run_for_config(component_model, self.config, str(component_output_path))
output_model_component.model_attributes = (
output_model_component.model_attributes or component_model.model_attributes
)
@@ -245,7 +255,7 @@ def run(
output_model = CompositeModelHandler(components, component_names)
output_model.model_attributes = output_model.model_attributes or model.model_attributes
else:
output_model = self._run_for_config(model, config, output_model_path)
output_model = self._run_for_config(model, self.config, output_model_path)
# assumption: the model attributes from passes, if any, are more important than
# the input model attributes, we should not update/extend anymore outside of the pass run
output_model.model_attributes = output_model.model_attributes or model.model_attributes
@@ -443,8 +453,7 @@ def _init_fixed_and_search_params(
assert not cyclic_search_space(search_space), "Search space is cyclic."
# TODO(jambayk): better error message, e.g. which parameters are invalid, how they are invalid
assert SearchSpace({"search_space": search_space}).size() > 0, "There are no valid points in the search space."

return {**fixed_params, **search_space}
return fixed_params, search_space

@classmethod
def _resolve_search_parameter(cls, param: SearchParameter, fixed_params: Dict[str, Any]) -> Any:
@@ -519,5 +528,5 @@ def create_pass_from_dict(
if accelerator_spec is None:
accelerator_spec = DEFAULT_CPU_ACCELERATOR

config = pass_cls.generate_search_space(accelerator_spec, config, disable_search)
config = pass_cls.generate_config(accelerator_spec, config, disable_search)
return pass_cls(accelerator_spec, config, host_device)
20 changes: 13 additions & 7 deletions olive/passes/onnx/model_builder.py
@@ -8,7 +8,7 @@
import json
import logging
from pathlib import Path
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union

import onnx
import transformers
@@ -92,15 +92,21 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
),
}

def validate_search_point(
self, search_point: Dict[str, Any], accelerator_spec: AcceleratorSpec, with_fixed_value: bool = False
@classmethod
def validate_config(
cls,
config: Dict[str, Any],
accelerator_spec: AcceleratorSpec,
disable_search: Optional[bool] = False,
) -> bool:
if with_fixed_value:
search_point = self.config_at_search_point(search_point or {})
precision = search_point.get("precision")
if not super().validate_config(config, accelerator_spec, disable_search):
return False

config_cls, _ = cls.get_config_class(accelerator_spec, disable_search)
config = config_cls(**config)

# if device is GPU, but user choose CPU EP, the is_cpu should be True
if (precision == ModelBuilder.Precision.FP16) and not (
if (config.precision == ModelBuilder.Precision.FP16) and not (
accelerator_spec.accelerator_type == Device.GPU
and accelerator_spec.execution_provider != "CPUExecutionProvider"
):
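
For instance, the reworked classmethod can reject an fp16 request on a CPU-only accelerator before the pass is ever constructed (a sketch; the pass type key and import path are assumptions):

```python
from olive.hardware.accelerator import DEFAULT_CPU_ACCELERATOR  # assumed import path
from olive.package_config import OlivePackageConfig

mb_cls = OlivePackageConfig.load_default_config().import_pass_module("ModelBuilder")

# FP16 output requires a GPU accelerator with a non-CPU execution provider, so this
# should come back False for the default CPU accelerator spec.
is_valid = mb_cls.validate_config({"precision": mb_cls.Precision.FP16}, DEFAULT_CPU_ACCELERATOR)
```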
32 changes: 17 additions & 15 deletions olive/passes/onnx/nvmo_quantization.py
@@ -4,7 +4,7 @@
# --------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union

import onnx
import torch
@@ -73,47 +73,49 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
),
}

def validate_search_point(
self,
search_point: Dict[str, Any],
@classmethod
def validate_config(
cls,
config: Dict[str, Any],
accelerator_spec: AcceleratorSpec,
with_fixed_value: bool = False,
disable_search: Optional[bool] = False,
) -> bool:
if with_fixed_value:
search_point = self.config_at_search_point(search_point or {})
if not super().validate_config(config, accelerator_spec, disable_search):
return False

config_cls, _ = cls.get_config_class(accelerator_spec, disable_search)
config = config_cls(**config)

# Validate Precision
if search_point.get("precision") != NVModelOptQuantization.Precision.INT4:
if config.precision != NVModelOptQuantization.Precision.INT4:
logger.error("Only INT4 quantization is supported.")
return False

# Validate Algorithm
if search_point.get("algorithm") not in [
NVModelOptQuantization.Algorithm.AWQ.value,
]:
if config.algorithm not in [NVModelOptQuantization.Algorithm.AWQ.value]:
logger.error("Only 'AWQ' algorithm is supported.")
return False

# Validate Calibration
if search_point.get("calibration") not in [
if config.calibration not in [
NVModelOptQuantization.Calibration.AWQ_LITE.value,
NVModelOptQuantization.Calibration.AWQ_CLIP.value,
]:
logger.error("Calibration method must be either 'awq_lite' or 'awq_clip'.")
return False

random_calib = search_point.get("random_calib_data", False)
random_calib = config.random_calib_data or False
if not isinstance(random_calib, bool):
logger.error("'random_calib_data' must be a boolean value.")
return False

tokenizer_dir = search_point.get("tokenizer_dir", "")
tokenizer_dir = config.tokenizer_dir or ""
if not random_calib and not tokenizer_dir:
logger.error("'tokenizer_dir' must be specified when 'random_calib_data' is False.")
return False

# Optional: Validate 'tokenizer_dir' if necessary
if not search_point.get("tokenizer_dir"):
if not config.tokenizer_dir:
logger.warning("Tokenizer directory 'tokenizer_dir' is not specified.")

return True
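
A hedged sketch of a configuration the rewritten checks should accept, built from the enum members referenced above; the pass type key, import path, and tokenizer path are assumptions:

```python
from olive.hardware.accelerator import DEFAULT_CPU_ACCELERATOR  # assumed import path
from olive.package_config import OlivePackageConfig

nv_cls = OlivePackageConfig.load_default_config().import_pass_module("NVModelOptQuantization")

# An accepted shape: INT4 precision, the AWQ algorithm, an awq_* calibration method,
# and a tokenizer_dir (required here because random_calib_data is not enabled).
ok = nv_cls.validate_config(
    {
        "precision": nv_cls.Precision.INT4,
        "algorithm": nv_cls.Algorithm.AWQ,
        "calibration": nv_cls.Calibration.AWQ_CLIP,
        "tokenizer_dir": "path/to/tokenizer",  # hypothetical path
    },
    DEFAULT_CPU_ACCELERATOR,
)
```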
19 changes: 13 additions & 6 deletions olive/passes/onnx/optimum_conversion.py
@@ -5,7 +5,7 @@
import logging
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

from olive.hardware.accelerator import AcceleratorSpec
from olive.model import CompositeModelHandler, HfModelHandler, ONNXModelHandler
@@ -48,13 +48,20 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
),
}

def validate_search_point(
self, search_point: Dict[str, Any], accelerator_spec: AcceleratorSpec, with_fixed_value: bool = False
@classmethod
def validate_config(
cls,
config: Dict[str, Any],
accelerator_spec: AcceleratorSpec,
disable_search: Optional[bool] = False,
) -> bool:
if with_fixed_value:
search_point = self.config_at_search_point(search_point or {})
if not super().validate_config(config, accelerator_spec, disable_search):
return False

config_cls, _ = cls.get_config_class(accelerator_spec, disable_search)
config = config_cls(**config)

if search_point.get("fp16") and search_point.get("device") != "cuda":
if config.fp16 and config.device != "cuda":
logger.info("OptimumConversion: fp16 is set to True, but device is not set to cuda.")
return False

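
And similarly for the fp16/device check above (a sketch; the pass type key and import path are assumptions):

```python
from olive.hardware.accelerator import DEFAULT_CPU_ACCELERATOR  # assumed import path
from olive.package_config import OlivePackageConfig

oc_cls = OlivePackageConfig.load_default_config().import_pass_module("OptimumConversion")

# fp16 export without device="cuda" is rejected up front, before any pass object exists.
oc_cls.validate_config({"fp16": True, "device": "cpu"}, DEFAULT_CPU_ACCELERATOR)   # -> False
oc_cls.validate_config({"fp16": True, "device": "cuda"}, DEFAULT_CPU_ACCELERATOR)  # -> True, barring other checks
```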