From be1278ca4f879a2b80de1d299a14031522fda5d9 Mon Sep 17 00:00:00 2001 From: shaahji <96227573+shaahji@users.noreply.github.com> Date: Tue, 21 Jan 2025 14:45:42 -0800 Subject: [PATCH] Validate pass config before instantiating the pass (#1553) ## Validate pass config before instantiating the pass * Use of search point should be limited to engine logic only. Rest of the Olive implementation should receive a validated configuration to use. * Validation is for complete configuration and not merely for a search point. * Fixed a few issues related to use of BasePassConfig vs. FullPassConfig. * Add local caching to OlivePackageConfig for loaded modules. * Renamed a few variables in engine logic to be explicit about use of pass config vs. pass run config. ## Checklist before requesting a review - [ ] Add unit tests for this change. - [ ] Make sure all tests can pass. - [ ] Update documents if necessary. - [ ] Lint and apply fixes to your code by running `lintrunner -a` - [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. - [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR. ## (Optional) Issue link --- olive/cache.py | 8 +- olive/engine/engine.py | 231 +++++++----------- olive/package_config.py | 26 +- olive/passes/olive_pass.py | 87 ++++--- olive/passes/onnx/model_builder.py | 20 +- olive/passes/onnx/nvmo_quantization.py | 32 +-- olive/passes/onnx/optimum_conversion.py | 19 +- olive/passes/onnx/quantization.py | 31 ++- olive/passes/onnx/session_params_tuning.py | 26 +- olive/passes/onnx/transformer_optimization.py | 33 +-- olive/passes/pass_config.py | 36 +-- olive/passes/pytorch/capture_split_info.py | 21 +- olive/passes/pytorch/torch_trt_conversion.py | 15 +- olive/systems/azureml/aml_system.py | 6 +- olive/systems/docker/docker_system.py | 10 +- .../isolated_ort/isolated_ort_system.py | 5 +- olive/systems/local.py | 7 +- olive/systems/olive_system.py | 3 +- .../python_environment_system.py | 8 +- olive/workflows/run/config.py | 4 + olive/workflows/run/run.py | 93 ++----- test/unit_test/engine/test_engine.py | 16 +- .../passes/common/test_user_script.py | 2 +- .../passes/onnx/test_bnb_quantization.py | 6 +- .../passes/onnx/test_optimum_conversion.py | 4 +- .../onnx/test_transformer_optimization.py | 12 +- .../passes/test_pass_serialization.py | 4 +- .../systems/docker/test_docker_system.py | 2 - .../test_python_environment_system.py | 12 +- test/unit_test/systems/test_local.py | 2 +- test/unit_test/utils.py | 6 +- 31 files changed, 369 insertions(+), 418 deletions(-) diff --git a/olive/cache.py b/olive/cache.py index 360f9db0b..23384adcd 100644 --- a/olive/cache.py +++ b/olive/cache.py @@ -11,7 +11,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from olive.common.config_utils import ConfigBase, convert_configs_to_dicts, validate_config from olive.common.constants import DEFAULT_CACHE_DIR, DEFAULT_WORKFLOW_ID @@ -274,7 +274,11 @@ def cache_olive_config(self, olive_config: Dict): logger.exception("Failed to cache olive config") def get_output_model_id( - self, pass_name: int, pass_config: dict, input_model_id: str, accelerator_spec: "AcceleratorSpec" = None + self, + pass_name: str, + pass_config: Dict[str, Any], + 
input_model_id: str, + accelerator_spec: "AcceleratorSpec" = None, ): run_json = self.get_run_json(pass_name, pass_config, input_model_id, accelerator_spec) return hash_dict(run_json)[:8] diff --git a/olive/engine/engine.py b/olive/engine/engine.py index dc330c55d..06ba28137 100644 --- a/olive/engine/engine.py +++ b/olive/engine/engine.py @@ -26,6 +26,7 @@ from olive.logging import enable_filelog from olive.model import ModelConfig from olive.package_config import OlivePackageConfig +from olive.strategy.search_parameter import SearchParameter from olive.strategy.search_strategy import SearchStrategy, SearchStrategyConfig from olive.systems.common import SystemType from olive.systems.system_config import SystemConfig @@ -52,7 +53,7 @@ def __init__( search_strategy: Optional[Union[Dict[str, Any], SearchStrategyConfig]] = None, host: Optional[Union[Dict[str, Any], "SystemConfig"]] = None, target: Optional[Union[Dict[str, Any], "SystemConfig"]] = None, - evaluator: Optional[Union[Dict[str, Any], "OliveEvaluatorConfig"]] = None, + evaluator: Optional[Union[Dict[str, Any], OliveEvaluatorConfig]] = None, cache_config: Optional[Union[Dict[str, Any], CacheConfig]] = None, plot_pareto_frontier: bool = False, no_artifacts: bool = False, @@ -83,14 +84,9 @@ def __init__( self.skip_saving_artifacts = no_artifacts self.azureml_client_config = azureml_client_config - # dictionary of passes - self.pass_config = OrderedDict() - - # {"pass_name": {"pass": pass, "host": host, "evaluator": evaluator} - self.passes = OrderedDict() - self.pass_flows = None - self.pass_flows_search_spaces = None - + self.pass_run_configs: Dict[str, Dict[str, Any]] = OrderedDict() + self.pass_flows: List[List[str]] = [] + self.search_spaces: List[List[Tuple[str, Dict[str, SearchParameter]]]] = [] self.footprints = defaultdict(Footprint) self._initialized = False @@ -111,15 +107,17 @@ def initialize(self, log_to_file: bool = False, log_severity_level: int = 1): if self.target_config.type != SystemType.AzureML: if self.evaluator_config: self.evaluator_config = self.cache.prepare_resources_for_local(self.evaluator_config) - for pass_config in self.pass_config.values(): - if pass_config["evaluator"]: - pass_config["evaluator"] = self.cache.prepare_resources_for_local(pass_config["evaluator"]) - for pass_config in self.pass_config.values(): - host_type = pass_config["host"].system_type if pass_config["host"] else self.host_config.type - if host_type == SystemType.AzureML: - continue - pass_config["config"] = self.cache.prepare_resources_for_local(pass_config["config"]) + for pass_run_config in self.pass_run_configs.values(): + if pass_run_config["evaluator"]: + pass_run_config["evaluator"] = self.cache.prepare_resources_for_local(pass_run_config["evaluator"]) + + for pass_run_config in self.pass_run_configs.values(): + host_type = pass_run_config["host"].system_type if pass_run_config["host"] else self.host_config.type + if host_type != SystemType.AzureML: + pass_run_config["input_config"] = self.cache.prepare_resources_for_local( + pass_run_config["input_config"] + ) self.set_pass_flows(self.pass_flows) self._initialized = True @@ -130,11 +128,11 @@ def register( config: Dict[str, Any] = None, name: str = None, host: "OliveSystem" = None, - evaluator_config: "OliveEvaluatorConfig" = None, + evaluator_config: OliveEvaluatorConfig = None, ): """Register a pass configuration so that it could be instantiated and executed later.""" - if name is not None: - assert name not in self.passes, f"Pass with name {name} already registered" + if 
name: + assert name not in self.pass_run_configs, f"Pass with name {name} already registered" else: idx = 0 while True: @@ -142,41 +140,18 @@ def register( if idx > 0: name = f"{name}_{idx}" idx += 1 - if name not in self.pass_config: + if name not in self.pass_run_configs: break pass_type_name = pass_type if isinstance(pass_type, str) else pass_type.__name__ logger.debug("Registering pass %s", pass_type_name) - pass_type = self.olive_config.import_pass_module(pass_type_name) - - self.pass_config[name] = { - "type": pass_type, - "config": config or {}, + self.pass_run_configs[name] = { + "type": pass_type_name, + "input_config": config or {}, "host": host, "evaluator": evaluator_config, } - def register_pass( - self, p: "Pass", name: str = None, host: "OliveSystem" = None, evaluator_config: "OliveEvaluatorConfig" = None - ): - """Register a pass instance.""" - if name is not None: - assert name not in self.passes, f"Pass with name {name} already registered" - else: - idx = 0 - while True: - name = p.__class__.__name__ - if idx > 0: - name = f"{name}_{idx}" - idx += 1 - if name not in self.passes: - break - - if not self.search_strategy and len(p.search_space) > 0: - raise ValueError(f"Search strategy is None but pass {name} has search space") - - self.passes[name] = {"pass": p, "host": host, "evaluator": evaluator_config} - def set_pass_flows(self, pass_flows: List[List[str]] = None): """Construct pass flows from a list of pass names. @@ -184,10 +159,7 @@ def set_pass_flows(self, pass_flows: List[List[str]] = None): pass_flows: a list of pass names, each pass name is a string. """ - if not pass_flows: - self.pass_flows = [list(self.pass_config.keys())] if self.pass_config else [] - else: - self.pass_flows = pass_flows + self.pass_flows = pass_flows or [list(self.pass_run_configs.keys())] def run( self, @@ -268,17 +240,15 @@ def run( accelerator_spec, ) - if run_result is None: - continue - - outputs[accelerator_spec] = run_result + if run_result: + outputs[accelerator_spec] = run_result for accelerator_spec in self.footprints: logger.info("Run history for %s:", accelerator_spec) run_history = self.footprints[accelerator_spec].summarize_run_history() self.dump_run_history(run_history, output_subdirs[accelerator_spec] / "run_history.txt") - if packaging_config and self.passes: + if packaging_config and self.pass_run_configs: # TODO(trajep): should we support packaging pytorch model? logger.info("Package top ranked %d models as artifacts", sum(len(f.nodes) for f in outputs.values())) generate_output_artifacts( @@ -294,7 +264,7 @@ def run( # TODO(team): refactor output structure # Do not change condition order. 
For no search, values of outputs are MetricResult # Consolidate the output structure for search and no search mode - if outputs and self.passes and not next(iter(outputs.values())).check_empty_nodes(): + if outputs and self.pass_run_configs and not next(iter(outputs.values())).check_empty_nodes(): best_node: FootprintNode = get_best_candidate_node(outputs, self.footprints) self.cache.save_model(model_id=best_node.model_id, output_dir=output_dir, overwrite=True) if len(accelerator_output_dir_list) > 1 and self.skip_saving_artifacts: @@ -310,8 +280,12 @@ def run_accelerator( evaluate_input_model: bool, accelerator_spec: "AcceleratorSpec", ): - # generate search space and initialize the passes for each hardware accelerator - self.setup_passes(accelerator_spec) + # Setup pass configs + self._setup_pass_configs(accelerator_spec) + + # generate search space + self._setup_search_spaces(accelerator_spec) + # hash the input model input_model_id = input_model_config.get_model_id() if input_model_id == LOCAL_INPUT_MODEL_ID and self.cache.enable_shared_cache: @@ -336,7 +310,8 @@ def run_accelerator( with results_path.open("w") as f: json.dump(results.to_json(), f, indent=4) logger.info("Saved evaluation results of input model to %s", results_path) - if not self.passes: + + if not self.pass_run_configs: logger.debug("No passes registered, return input model evaluation results.") return results @@ -365,38 +340,30 @@ def run_accelerator( return output_footprint def get_host_device(self): - if self.host_config.config.accelerators: - # for host device, we will always use the first accelerator device - return self.host_config.config.accelerators[0].device - else: - return None + # for host device, we will always use the first accelerator device + return self.host_config.config.accelerators[0].device if self.host_config.config.accelerators else None + + def _setup_pass_configs(self, accelerator_spec: "AcceleratorSpec"): + disable_search = self.search_strategy is None + for pass_run_config in self.pass_run_configs.values(): + pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_run_config["type"]) + pass_run_config["config"] = pass_cls.generate_config( + accelerator_spec, pass_run_config["input_config"], disable_search + ) + + def _setup_search_spaces(self, accelerator_spec: "AcceleratorSpec"): + self.search_spaces.clear() + if self.search_strategy is None: + return - def setup_passes(self, accelerator_spec: "AcceleratorSpec"): - host_device = self.get_host_device() - # clean the passes - self.passes.clear() - for name, config in self.pass_config.items(): - pass_cls: Type[Pass] = config["type"] - pass_cfg = config["config"] - pass_cfg = pass_cls.generate_search_space(accelerator_spec, pass_cfg, self.search_strategy is None) - p = pass_cls(accelerator_spec, pass_cfg, host_device) - self.register_pass(p, name=name, host=config["host"], evaluator_config=config["evaluator"]) - - # list of passes starting from the first pass with non-empty search space - # These passes will be added to the search space - self.pass_flows_search_spaces = [] for pass_flow in self.pass_flows: - pass_search_spaces = [] + pass_search_spaces: List[Tuple[str, Dict[str, SearchParameter]]] = [] for pass_name in pass_flow: - p: Pass = self.passes[pass_name]["pass"] - pass_search_spaces.append((pass_name, p.search_space)) - self.pass_flows_search_spaces.append(pass_search_spaces) - - def reset_passes(self): - """Cleanup the passes.""" - self.passes.clear() - self.pass_config.clear() - self.pass_flows = [] + pass_run_config = 
self.pass_run_configs[pass_name] + pass_search_spaces.append( + (pass_name, {k: v for k, v in pass_run_config["config"].items() if isinstance(v, SearchParameter)}) + ) + self.search_spaces.append(pass_search_spaces) def run_no_search( self, @@ -406,17 +373,12 @@ def run_no_search( output_dir: Path, ): """Run all the registered Olive pass flows in no-search mode.""" - for pass_item in self.passes.values(): - if len(pass_item["pass"].search_space) > 0: - pass_name = pass_item["name"] - raise ValueError(f"Pass {pass_name} has search space but search strategy is None") - output_model_dir = Path(output_dir) output_model_ids = [] for pass_flow in self.pass_flows: # search point is empty since there is no search - passes_to_run = [(pass_id, {}) for pass_id in pass_flow] + passes_to_run = [(pass_name, {}) for pass_name in pass_flow] # run all the passes in the pass flow logger.debug("Running %s with no search ...", pass_flow) @@ -461,7 +423,7 @@ def run_search( ): """Run all the registered Olive passes in search model where search strategy is not None.""" # get objective_dict - evaluator_config = self.evaluator_for_pass(list(self.passes.keys())[-1]) + evaluator_config = self.evaluator_for_pass(list(self.pass_run_configs.keys())[-1]) if evaluator_config is None: raise ValueError("No evaluator provided for the last pass") @@ -472,7 +434,7 @@ def run_search( self.footprints[accelerator_spec].record_objective_dict(objective_dict) # initialize the search strategy - self.search_strategy.initialize(self.pass_flows_search_spaces, input_model_id, objective_dict) + self.search_strategy.initialize(self.search_spaces, input_model_id, objective_dict) output_model_num = self.search_strategy.get_output_model_num() # record start time @@ -490,10 +452,7 @@ def run_search( # get the model id of the first input model model_id = next_step["model_id"] - if model_id == input_model_id: - model_config = input_model_config - else: - model_config = self._load_model(model_id) + model_config = input_model_config if model_id == input_model_id else self._load_model(model_id) logger.debug("Step %d with search point %s ...", iter_num, next_step["search_point"]) @@ -639,18 +598,13 @@ def resolve_goals( return resolved_goals - def host_for_pass(self, pass_id: str): - host = self.passes[pass_id]["host"] - if host is None: - return self.host - return host + def host_for_pass(self, pass_name: str) -> "OliveSystem": + host: SystemConfig = self.pass_run_configs[pass_name]["host"] + return host or self.host - def evaluator_for_pass(self, pass_id: str): + def evaluator_for_pass(self, pass_name: str) -> OliveEvaluatorConfig: """Return evaluator for the given pass.""" - e = self.passes[pass_id]["evaluator"] - if e is None: - return self.evaluator_config - return e + return self.pass_run_configs[pass_name]["evaluator"] or self.evaluator_config def _cache_model(self, model_id: str, model: Union[ModelConfig, str], check_object: bool = True): # TODO(trajep): move model/pass run/evaluation cache into footprints @@ -681,11 +635,11 @@ def _run_passes( should_prune = False # run all the passes in the step model_ids = [] - pass_id = None + pass_name = None - for pass_id, pass_search_point in passes: + for pass_name, pass_search_point in passes: model_config, model_id = self._run_pass( - pass_id, + pass_name, pass_search_point, model_config, model_id, @@ -693,7 +647,7 @@ def _run_passes( ) if model_config in PRUNED_CONFIGS: should_prune = True - logger.debug("Pruned for pass %s", pass_id) + logger.debug("Pruned for pass %s", pass_name) break 
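
The reworked `_run_pass` below is the heart of this change: the engine resolves the search point into a complete configuration, validates it, and only then instantiates the pass. A condensed sketch of that order of operations, simplified from the diff that follows (caching, footprints, and error handling omitted):

```python
# Condensed from Engine._run_pass in this patch; not the verbatim implementation.
pass_cls = self.olive_config.import_pass_module(pass_run_config["type"])
pass_config = pass_cls.config_at_search_point(pass_search_point, accelerator_spec, pass_run_config["config"])
if not pass_cls.validate_config(pass_config, accelerator_spec, self.search_strategy is None):
    return INVALID_CONFIG, None  # prune before a Pass object ever exists
p = pass_cls(accelerator_spec, pass_config, self.get_host_device())
output_model_config = host.run_pass(p, input_model_config, output_model_path)
```
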
model_ids.append(model_id) @@ -702,7 +656,7 @@ def _run_passes( if not should_prune: # evaluate the model - evaluator_config = self.evaluator_for_pass(pass_id) + evaluator_config = self.evaluator_for_pass(pass_name) if not self.search_strategy and evaluator_config is None: # skip evaluation if no search and no evaluator signal = None @@ -718,35 +672,38 @@ def _run_passes( def _run_pass( self, - pass_id: str, + pass_name: str, pass_search_point: Dict[str, Any], input_model_config: ModelConfig, input_model_id: str, accelerator_spec: "AcceleratorSpec", ): """Run a pass on the input model.""" - # pass run_start_time = datetime.now().timestamp() - p: Pass = self.passes[pass_id]["pass"] - pass_name = p.__class__.__name__ - logger.info("Running pass %s:%s %s", pass_id, pass_name, pass_search_point) - pass_config = p.config_at_search_point(pass_search_point) - pass_config = p.serialize_config(pass_config) - output_model_config = None + pass_run_config: Dict[str, Any] = self.pass_run_configs[pass_name] + pass_type_name = pass_run_config["type"] + + logger.info("Running pass %s:%s %s", pass_name, pass_type_name, pass_search_point) # check whether the config is valid - if not p.validate_search_point(pass_search_point, accelerator_spec, with_fixed_value=True): - logger.warning("Invalid search point, prune") - output_model_config = INVALID_CONFIG + pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_run_config["type"]) + pass_config = pass_cls.config_at_search_point(pass_search_point, accelerator_spec, pass_run_config["config"]) + if not pass_cls.validate_config(pass_config, accelerator_spec, self.search_strategy is None): + logger.warning("Invalid config, pruned.") + logger.debug(pass_config) # no need to record in footprint since there was no run and thus no valid/failed model # invalid configs are also not cached since the same config can be valid for other accelerator specs # a pass can be accelerator agnostic but still have accelerator specific invalid configs # this helps reusing cached models for different accelerator specs - return output_model_config, None + return INVALID_CONFIG, None + + p: Pass = pass_cls(accelerator_spec, pass_config, self.get_host_device()) + pass_config = p.serialize_config(pass_config, check_object=True) + output_model_config = None # load run from cache if it exists run_accel = None if p.is_accelerator_agnostic(accelerator_spec) else accelerator_spec - output_model_id = self.cache.get_output_model_id(pass_name, pass_config, input_model_id, run_accel) + output_model_id = self.cache.get_output_model_id(pass_type_name, pass_config, input_model_id, run_accel) run_cache = self.cache.load_run_from_model_id(output_model_id) if run_cache: logger.debug("Loading model from cache ...") @@ -759,7 +716,7 @@ def _run_pass( output_model_config.to_json() if output_model_config != FAILED_CONFIG else {"is_pruned": True} ), parent_model_id=input_model_id, - from_pass=pass_name, + from_pass=pass_type_name, pass_run_config=pass_config, start_time=run_start_time, end_time=datetime.now().timestamp(), @@ -771,7 +728,7 @@ def _run_pass( if input_model_config.config.get("shared_cache", False): input_model_config = self.cache.download_shared_cache_model(input_model_config, input_model_id) - host = self.host_for_pass(pass_id) + host = self.host_for_pass(pass_name) if host.system_type != SystemType.AzureML: input_model_config = self.cache.prepare_resources_for_local(input_model_config) @@ -779,12 +736,12 @@ def _run_pass( if p.run_on_target: if self.target.system_type == 
SystemType.IsolatedORT: logger.warning( - "Cannot run pass %s on IsolatedORT target, will use the host to run the pass.", pass_id + "Cannot run pass %s on IsolatedORT target, will use the host to run the pass.", pass_name ) else: host = self.target - output_model_config = host.run_pass(p, input_model_config, output_model_path, pass_search_point) + output_model_config = host.run_pass(p, input_model_config, output_model_path) except OlivePassError: logger.exception("Pass run_pass failed") output_model_config = FAILED_CONFIG @@ -801,20 +758,20 @@ def _run_pass( raise # rethrow the exception if no search is performed run_end_time = datetime.now().timestamp() - logger.info("Pass %s:%s finished in %f seconds", pass_id, pass_name, run_end_time - run_start_time) + logger.info("Pass %s:%s finished in %f seconds", pass_name, pass_type_name, run_end_time - run_start_time) # cache model self._cache_model(output_model_id, output_model_config) # cache run - self.cache.cache_run(pass_name, pass_config, input_model_id, output_model_id, run_accel) + self.cache.cache_run(pass_type_name, pass_config, input_model_id, output_model_id, run_accel) # footprint model and run self.footprints[accelerator_spec].record( model_id=output_model_id, model_config=output_model_config.to_json() if output_model_config != FAILED_CONFIG else {"is_pruned": True}, parent_model_id=input_model_id, - from_pass=pass_name, + from_pass=pass_type_name, pass_run_config=pass_config, start_time=run_start_time, end_time=run_end_time, @@ -850,7 +807,7 @@ def _evaluate_model( self, model_config: ModelConfig, model_id: str, - evaluator_config: "OliveEvaluatorConfig", + evaluator_config: OliveEvaluatorConfig, accelerator_spec: "AcceleratorSpec", ): """Evaluate a model.""" diff --git a/olive/package_config.py b/olive/package_config.py index 8e2118db6..cf9090cc8 100644 --- a/olive/package_config.py +++ b/olive/package_config.py @@ -5,12 +5,15 @@ import functools import importlib from pathlib import Path -from typing import Dict, List +from typing import TYPE_CHECKING, ClassVar, Dict, List, Type from olive.common.config_utils import ConfigBase -from olive.common.pydantic_v1 import validator +from olive.common.pydantic_v1 import Field, validator from olive.passes import PassModuleConfig +if TYPE_CHECKING: + from olive.passes.olive_pass import Pass + class OlivePackageConfig(ConfigBase): """Configuration for an Olive package. @@ -18,8 +21,10 @@ class OlivePackageConfig(ConfigBase): passes key is case-insensitive and stored in lowercase. 
""" - passes: Dict[str, PassModuleConfig] - extra_dependencies: Dict[str, List[str]] + passes: Dict[str, PassModuleConfig] = Field(default_factory=dict) + extra_dependencies: Dict[str, List[str]] = Field(default_factory=dict) + + _pass_modules: ClassVar[Dict[str, Type["Pass"]]] = {} @validator("passes") def validate_passes(cls, values): @@ -36,12 +41,13 @@ def load_default_config() -> "OlivePackageConfig": return OlivePackageConfig.parse_file(OlivePackageConfig.get_default_config_path()) def import_pass_module(self, pass_type: str): - pass_module_config = self.get_pass_module_config(pass_type) - module_path, module_name = pass_module_config.module_path.rsplit(".", 1) - module = importlib.import_module(module_path, module_name) - cls = getattr(module, module_name) - pass_module_config.set_class_variables(cls) - return cls + if pass_type not in self._pass_modules: + pass_module_config = self.get_pass_module_config(pass_type) + module_path, module_name = pass_module_config.module_path.rsplit(".", 1) + module = importlib.import_module(module_path, module_name) + self._pass_modules[pass_type] = cls = getattr(module, module_name) + pass_module_config.set_class_variables(cls) + return self._pass_modules[pass_type] def get_pass_module_config(self, pass_type: str) -> PassModuleConfig: if "." in pass_type: diff --git a/olive/passes/olive_pass.py b/olive/passes/olive_pass.py index 9fd400401..1f9cbb061 100644 --- a/olive/passes/olive_pass.py +++ b/olive/passes/olive_pass.py @@ -75,7 +75,7 @@ def __init__( assert config is not None, "Please specify the configuration for the pass." # NOTE: The :disable_search argument isn't impactful here since the search isn't - # dependent on it. The same parameter in :generate_search_space is what decides + # dependent on it. The same parameter in :generate_config is what decides # how search points are handled. HEre, Using default values for each config # parameter in the config class keeps it simple. 
config_class, default_config = self.get_config_class(accelerator_spec, True) @@ -87,19 +87,12 @@ def __init__( self.config = config self._user_module_loader = UserModuleLoader(self.config.get("user_script"), self.config.get("script_dir")) - self._fixed_params = {} - self.search_space = {} - for k, v in self.config.items(): - if isinstance(v, SearchParameter): - self.search_space[k] = v - else: - self._fixed_params[k] = v - # Params that are paths [(param_name, required)] - self.path_params = [] - for param, param_config in default_config.items(): - if param_config.category in (ParamCategory.PATH, ParamCategory.DATA): - self.path_params.append((param, param_config.required, param_config.category)) + self.path_params = [ + (param, param_config.required, param_config.category) + for param, param_config in default_config.items() + if param_config.category in (ParamCategory.PATH, ParamCategory.DATA) + ] self._initialized = False @@ -113,30 +106,32 @@ def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool: return True @classmethod - def generate_search_space( + def generate_config( cls, accelerator_spec: AcceleratorSpec, - config: Optional[Union[Dict[str, Any], BasePassConfig]] = None, + config: Optional[Dict[str, Any]] = None, disable_search: Optional[bool] = False, - ) -> Tuple[Type[BasePassConfig], Dict[str, Any]]: + ) -> Dict[str, Any]: """Generate search space for the pass.""" assert accelerator_spec is not None, "Please specify the accelerator spec for the pass" + config = config or {} # Get the config class with default value or default search value config_class, default_config = cls.get_config_class(accelerator_spec, disable_search) if not disable_search: # Replace user-provided values with Categorical if user intended to search - config = cls.identify_search_values(config, default_config) + config = cls._identify_search_values(config, default_config) # Generate the search space by using both default value and default search value and user provided config config = validate_config(config, config_class) config = cls._resolve_config(config, default_config) - return cls._init_fixed_and_search_params(config, default_config) + fixed_values, search_params = cls._init_fixed_and_search_params(config, default_config) + return {**fixed_values, **search_params} @classmethod - def identify_search_values( + def _identify_search_values( cls, config: Dict[str, Any], default_config: Dict[str, PassConfigParam], @@ -189,26 +184,41 @@ def default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConf ), f"{param} ending with data_config must be of type DataConfig." return config - def config_at_search_point(self, point: Dict[str, Any]) -> Dict[str, Any]: + @classmethod + def config_at_search_point( + cls, + point: Dict[str, Any], + accelerator_spec: AcceleratorSpec, + config: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: """Get the configuration for the pass at a specific point in the search space.""" - assert set(point.keys()) == set(self.search_space.keys()), "Search point is not in the search space." 
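
Because `generate_config` (which replaces `generate_search_space`) and `config_at_search_point` are now classmethods, a fully resolved configuration can be produced and checked without constructing a `Pass`. A hypothetical usage sketch — `MyPass`, `opt_level`, `point`, and `host_device` are illustrative names, not part of this patch:

```python
# All values come back fixed when search is disabled:
config = MyPass.generate_config(accelerator_spec, {"opt_level": 1}, disable_search=True)
# With search enabled, searchable parameters remain SearchParameter objects:
space = MyPass.generate_config(accelerator_spec, {}, disable_search=False)
# A chosen search point is merged over the fixed values to yield a concrete config:
concrete = MyPass.config_at_search_point(point, accelerator_spec, {"opt_level": 1})
if MyPass.validate_config(concrete, accelerator_spec):
    p = MyPass(accelerator_spec, concrete, host_device)
```
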
- config = self._fixed_params.copy() - config.update(**point) - return self._config_class(**config).dict() + assert accelerator_spec is not None, "Please specify the accelerator spec for the pass" + + # Get the config class with default search value + config_class, default_config = cls.get_config_class(accelerator_spec) + + # Replace user-provided values with Categorical if user intended to search + config = cls._identify_search_values(config or {}, default_config) + + # Generate the search space by using both default value and default search value and user provided config + config = validate_config(config, config_class) + config = cls._resolve_config(config, default_config) + fixed_values, search_params = cls._init_fixed_and_search_params(config, default_config) + assert set(point.keys()) == set(search_params.keys()), "Search point is not in the search space." + return {**fixed_values, **search_params, **point} - def validate_search_point( - self, search_point: Dict[str, Any], accelerator_spec: AcceleratorSpec, with_fixed_value: bool = False + @classmethod + def validate_config( + cls, + config: Dict[str, Any], + accelerator_spec: AcceleratorSpec, + disable_search: Optional[bool] = False, ) -> bool: - """Validate the search point for the pass.""" + """Validate the input config for the pass.""" return True - def run( - self, model: OliveModelHandler, output_model_path: str, point: Optional[Dict[str, Any]] = None - ) -> OliveModelHandler: + def run(self, model: OliveModelHandler, output_model_path: str) -> OliveModelHandler: """Run the pass on the model at a specific point in the search space.""" - point = point or {} - config = self.config_at_search_point(point) - if not self._initialized: self._initialize() self._initialized = True @@ -218,7 +228,7 @@ def run( for rank in range(model.num_ranks): input_ranked_model = model.load_model(rank) ranked_output_path = Path(output_model_path).with_suffix("") / model.ranked_model_name(rank) - self._run_for_config(input_ranked_model, config, str(ranked_output_path)) + self._run_for_config(input_ranked_model, self.config, str(ranked_output_path)) # ranked model don't have their own model_attributes, they are just part of the distributed model # which has the model_attributes @@ -235,7 +245,7 @@ def run( component_names = [] for component_name, component_model in model.get_model_components(): component_output_path = Path(output_model_path).with_suffix("") / component_name - output_model_component = self._run_for_config(component_model, config, str(component_output_path)) + output_model_component = self._run_for_config(component_model, self.config, str(component_output_path)) output_model_component.model_attributes = ( output_model_component.model_attributes or component_model.model_attributes ) @@ -245,7 +255,7 @@ def run( output_model = CompositeModelHandler(components, component_names) output_model.model_attributes = output_model.model_attributes or model.model_attributes else: - output_model = self._run_for_config(model, config, output_model_path) + output_model = self._run_for_config(model, self.config, output_model_path) # assumption: the model attributes from passes, if any, are more important than # the input model attributes, we should not update/extend anymore outside of the pass run output_model.model_attributes = output_model.model_attributes or model.model_attributes @@ -443,8 +453,7 @@ def _init_fixed_and_search_params( assert not cyclic_search_space(search_space), "Search space is cyclic." # TODO(jambayk): better error message, e.g. 
which parameters are invalid, how they are invalid assert SearchSpace({"search_space": search_space}).size() > 0, "There are no valid points in the search space." - - return {**fixed_params, **search_space} + return fixed_params, search_space @classmethod def _resolve_search_parameter(cls, param: SearchParameter, fixed_params: Dict[str, Any]) -> Any: @@ -519,5 +528,5 @@ def create_pass_from_dict( if accelerator_spec is None: accelerator_spec = DEFAULT_CPU_ACCELERATOR - config = pass_cls.generate_search_space(accelerator_spec, config, disable_search) + config = pass_cls.generate_config(accelerator_spec, config, disable_search) return pass_cls(accelerator_spec, config, host_device) diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py index 8e73188c2..7e7b9cf04 100644 --- a/olive/passes/onnx/model_builder.py +++ b/olive/passes/onnx/model_builder.py @@ -8,7 +8,7 @@ import json import logging from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import onnx import transformers @@ -92,15 +92,21 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon ), } - def validate_search_point( - self, search_point: Dict[str, Any], accelerator_spec: AcceleratorSpec, with_fixed_value: bool = False + @classmethod + def validate_config( + cls, + config: Dict[str, Any], + accelerator_spec: AcceleratorSpec, + disable_search: Optional[bool] = False, ) -> bool: - if with_fixed_value: - search_point = self.config_at_search_point(search_point or {}) - precision = search_point.get("precision") + if not super().validate_config(config, accelerator_spec, disable_search): + return False + + config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) + config = config_cls(**config) # if device is GPU, but user choose CPU EP, the is_cpu should be True - if (precision == ModelBuilder.Precision.FP16) and not ( + if (config.precision == ModelBuilder.Precision.FP16) and not ( accelerator_spec.accelerator_type == Device.GPU and accelerator_spec.execution_provider != "CPUExecutionProvider" ): diff --git a/olive/passes/onnx/nvmo_quantization.py b/olive/passes/onnx/nvmo_quantization.py index e54866bc5..b44feb65f 100644 --- a/olive/passes/onnx/nvmo_quantization.py +++ b/olive/passes/onnx/nvmo_quantization.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import logging from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import onnx import torch @@ -73,47 +73,49 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon ), } - def validate_search_point( - self, - search_point: Dict[str, Any], + @classmethod + def validate_config( + cls, + config: Dict[str, Any], accelerator_spec: AcceleratorSpec, - with_fixed_value: bool = False, + disable_search: Optional[bool] = False, ) -> bool: - if with_fixed_value: - search_point = self.config_at_search_point(search_point or {}) + if not super().validate_config(config, accelerator_spec, disable_search): + return False + + config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) + config = config_cls(**config) # Validate Precision - if search_point.get("precision") != NVModelOptQuantization.Precision.INT4: + if config.precision != NVModelOptQuantization.Precision.INT4: logger.error("Only INT4 quantization is supported.") return False # Validate Algorithm - if search_point.get("algorithm") not in [ - 
NVModelOptQuantization.Algorithm.AWQ.value, - ]: + if config.algorithm not in [NVModelOptQuantization.Algorithm.AWQ.value]: logger.error("Only 'AWQ' algorithm is supported.") return False # Validate Calibration - if search_point.get("calibration") not in [ + if config.calibration not in [ NVModelOptQuantization.Calibration.AWQ_LITE.value, NVModelOptQuantization.Calibration.AWQ_CLIP.value, ]: logger.error("Calibration method must be either 'awq_lite' or 'awq_clip'.") return False - random_calib = search_point.get("random_calib_data", False) + random_calib = config.random_calib_data or False if not isinstance(random_calib, bool): logger.error("'random_calib_data' must be a boolean value.") return False - tokenizer_dir = search_point.get("tokenizer_dir", "") + tokenizer_dir = config.tokenizer_dir or "" if not random_calib and not tokenizer_dir: logger.error("'tokenizer_dir' must be specified when 'random_calib_data' is False.") return False # Optional: Validate 'tokenizer_dir' if necessary - if not search_point.get("tokenizer_dir"): + if not config.tokenizer_dir: logger.warning("Tokenizer directory 'tokenizer_dir' is not specified.") return True diff --git a/olive/passes/onnx/optimum_conversion.py b/olive/passes/onnx/optimum_conversion.py index b0c5e03cd..153bf6825 100644 --- a/olive/passes/onnx/optimum_conversion.py +++ b/olive/passes/onnx/optimum_conversion.py @@ -5,7 +5,7 @@ import logging from copy import deepcopy from pathlib import Path -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from olive.hardware.accelerator import AcceleratorSpec from olive.model import CompositeModelHandler, HfModelHandler, ONNXModelHandler @@ -48,13 +48,20 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon ), } - def validate_search_point( - self, search_point: Dict[str, Any], accelerator_spec: AcceleratorSpec, with_fixed_value: bool = False + @classmethod + def validate_config( + cls, + config: Dict[str, Any], + accelerator_spec: AcceleratorSpec, + disable_search: Optional[bool] = False, ) -> bool: - if with_fixed_value: - search_point = self.config_at_search_point(search_point or {}) + if not super().validate_config(config, accelerator_spec, disable_search): + return False + + config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) + config = config_cls(**config) - if search_point.get("fp16") and search_point.get("device") != "cuda": + if config.fp16 and config.device != "cuda": logger.info("OptimumConversion: fp16 is set to True, but device is not set to cuda.") return False diff --git a/olive/passes/onnx/quantization.py b/olive/passes/onnx/quantization.py index 8c2aa342d..42504b90b 100644 --- a/olive/passes/onnx/quantization.py +++ b/olive/passes/onnx/quantization.py @@ -7,7 +7,7 @@ from copy import deepcopy from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Optional, Union import onnx from packaging import version @@ -324,17 +324,24 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon config.update(get_external_data_config()) return config - def validate_search_point( - self, search_point: Dict[str, Any], accelerator_spec: AcceleratorSpec, with_fixed_value: bool = False + @classmethod + def validate_config( + cls, + config: Dict[str, Any], + accelerator_spec: AcceleratorSpec, + disable_search: Optional[bool] = False, ) -> bool: - config = search_point or {} - if 
with_fixed_value: - config = self.config_at_search_point(search_point) - if config["quant_mode"] == "static": + if not super().validate_config(config, accelerator_spec, disable_search): + return False + + config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) + config = config_cls(**config) + + if config.quant_mode == "static": if ( - config["weight_type"] == "QInt8" - and config["activation_type"] == "QInt8" - and config["quant_format"] == "QOperator" + config.weight_type == "QInt8" + and config.activation_type == "QInt8" + and config.quant_format == "QOperator" ): # S8S8 with QOperator will be slow on x86-64 CPUs and should be avoided in general. # https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection @@ -342,7 +349,7 @@ def validate_search_point( logger.warning( "S8S8 with QOperator will be slow on x86-64 CPUs and should be avoided in general, try QDQ instead." ) - if config["EnableSubgraph"] is True: + if config.EnableSubgraph is True: logger.info("EnableSubgraph is not supported for static quantization.") return False return True @@ -364,7 +371,7 @@ def _run_for_config( if is_static: assert config["data_config"], "data_config is required for static quantization." # whether to prepare qnn config - # we do the version check here and not in `validate_search_point` since search point validation + # we do the version check here and not in `validate_config` since search point validation # is done by the engine. Unless the host is local system, the ort version of the host is # not known by the engine when the search point is validated. if config["prepare_qnn_config"] and version.parse(OrtVersion) < version.parse("1.17.0"): diff --git a/olive/passes/onnx/session_params_tuning.py b/olive/passes/onnx/session_params_tuning.py index a52ade95e..9e69f4ec1 100644 --- a/olive/passes/onnx/session_params_tuning.py +++ b/olive/passes/onnx/session_params_tuning.py @@ -7,7 +7,7 @@ import logging import tempfile from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import onnxruntime as ort @@ -170,15 +170,22 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon ), } - def validate_search_point( - self, search_point: Dict[str, Any], accelerator_spec: AcceleratorSpec, with_fixed_value: bool = False + @classmethod + def validate_config( + cls, + config: Dict[str, Any], + accelerator_spec: AcceleratorSpec, + disable_search: Optional[bool] = False, ) -> bool: """Validate the search point for the pass.""" - config = self.config_at_search_point(search_point or {}) + if not super().validate_config(config, accelerator_spec, disable_search): + return False + + config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) + config_cls.__config__.extra = Extra.allow + config = config_cls(**config) # Rename the search parameters with atomic/singular names for clarity - self._config_class.__config__.extra = Extra.allow - config = self._config_class(**config) config.execution_provider = config.providers_list config.provider_options = config.provider_options_list config.execution_mode = config.execution_mode_list @@ -246,14 +253,11 @@ def validate_search_point( return False # TODO(myguo): we need disable the following check when we enable cache in perf tuning. 
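
Every pass-specific override in this patch follows the same skeleton: defer to the base class, materialize the typed config object, then run its own checks. Distilled into a template (hypothetical pass; the concrete overrides appear throughout the surrounding diffs):

```python
@classmethod
def validate_config(
    cls, config: Dict[str, Any], accelerator_spec: AcceleratorSpec, disable_search: Optional[bool] = False
) -> bool:
    if not super().validate_config(config, accelerator_spec, disable_search):
        return False
    config_cls, _ = cls.get_config_class(accelerator_spec, disable_search)
    config = config_cls(**config)  # typed attribute access from here on
    # ... pass-specific checks; return False to prune this configuration ...
    return True
```
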
- if ( - config.execution_provider != self.accelerator_spec.execution_provider - and not config.force_evaluate_other_eps - ): + if config.execution_provider != accelerator_spec.execution_provider and not config.force_evaluate_other_eps: logger.warning( "Ignore perf tuning for EP %s since current pass EP is %s", config.execution_provider, - self.accelerator_spec.execution_provider, + accelerator_spec.execution_provider, ) return False diff --git a/olive/passes/onnx/transformer_optimization.py b/olive/passes/onnx/transformer_optimization.py index 6a6406d95..17b51e963 100644 --- a/olive/passes/onnx/transformer_optimization.py +++ b/olive/passes/onnx/transformer_optimization.py @@ -5,7 +5,7 @@ import logging from copy import deepcopy from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import onnx @@ -135,15 +135,23 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon config.update(get_external_data_config()) return config - def validate_search_point( - self, search_point: Dict[str, Any], accelerator_spec: AcceleratorSpec, with_fixed_value: bool = False + @classmethod + def validate_config( + cls, + config: Dict[str, Any], + accelerator_spec: AcceleratorSpec, + disable_search: Optional[bool] = False, ) -> bool: + if not super().validate_config(config, accelerator_spec, disable_search): + return False + from onnxruntime import __version__ as OrtVersion from packaging import version - if with_fixed_value: - search_point = self.config_at_search_point(search_point or {}) - if search_point.get("float16"): + config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) + config = config_cls(**config) + + if config.float16: if accelerator_spec.execution_provider == "TensorrtExecutionProvider": logger.info( "TensorRT has its own float16 implementation, please avoid to use float16 in transformers " @@ -153,21 +161,16 @@ def validate_search_point( if accelerator_spec.execution_provider == "CPUExecutionProvider": logger.info("CPUExecutionProvider does not support float16 very well, please avoid to use float16.") return False - if not search_point.get("float16") and search_point.get("use_gqa"): + if not config.float16 and config.use_gqa: logger.info("use_gqa is only supported when float16 is True.") return False - if search_point.get("use_gpu") and accelerator_spec.execution_provider == "CPUExecutionProvider": + if config.use_gpu and accelerator_spec.execution_provider == "CPUExecutionProvider": logger.info("CPUExecutionProvider does not support GPU inference, please avoid to use use_gpu.") return False - if search_point.get("only_onnxruntime") and search_point.get("opt_level") <= 0: + if config.only_onnxruntime and config.opt_level <= 0: logger.info("Please specify a positive value for opt_level when only_onnxruntime is True") return False - if ( - search_point.get("opt_level") == 0 - and search_point.get("only_onnxruntime") - and search_point.get("num_heads") == 0 - and search_point.get("hidden_size") == 0 - ): + if config.opt_level == 0 and config.only_onnxruntime and config.num_heads == 0 and config.hidden_size == 0: if version.parse(OrtVersion) <= version.parse("1.16.0"): logger.info( "Ignore this search point because the issue https://github.com/microsoft/onnxruntime/issues/17254" diff --git a/olive/passes/pass_config.py b/olive/passes/pass_config.py index f97dfbaae..82695f9d9 100644 --- a/olive/passes/pass_config.py +++ b/olive/passes/pass_config.py @@ -92,24 
+92,6 @@ def get_user_script_data_config( DEFAULT_SET = set(PassParamDefault) -class AbstractPassConfig(NestedConfig): - """Base class for pass configuration.""" - - type: str = Field(description="The type of the pass.") - config: Dict[str, Any] = Field( - None, - description=( - "The configuration of the pass. Values for required parameters must be provided. For optional parameters," - " default values or searchable values (if available and search is not disabled) will be used if not" - " provided." - ), - ) - - @validator("type", pre=True) - def validate_type(cls, v): - return validate_lowercase(v) - - class BasePassConfig(ConfigBase): @validator("*", pre=True) @@ -129,6 +111,24 @@ def _validate_search_parameter(cls, v): return v +class AbstractPassConfig(NestedConfig): + """Base class for pass configuration.""" + + type: str = Field(description="The type of the pass.") + config: Dict[str, Any] = Field( + None, + description=( + "The configuration of the pass. Values for required parameters must be provided. For optional parameters," + " default values or searchable values (if available and search is not disabled) will be used if not" + " provided." + ), + ) + + @validator("type", pre=True) + def validate_type(cls, v): + return validate_lowercase(v) + + def create_config_class( pass_type: str, default_config: Dict[str, PassConfigParam], diff --git a/olive/passes/pytorch/capture_split_info.py b/olive/passes/pytorch/capture_split_info.py index 2b46eb31e..d59a02d99 100644 --- a/olive/passes/pytorch/capture_split_info.py +++ b/olive/passes/pytorch/capture_split_info.py @@ -6,7 +6,7 @@ import logging from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import numpy as np @@ -63,17 +63,24 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon ), } - def validate_search_point( - self, search_point: Dict[str, Any], accelerator_spec: AcceleratorSpec, with_fixed_value: bool = False + @classmethod + def validate_config( + cls, + config: Dict[str, Any], + accelerator_spec: AcceleratorSpec, + disable_search: Optional[bool] = False, ) -> bool: - if with_fixed_value: - search_point = self.config_at_search_point(search_point or {}) + if not super().validate_config(config, accelerator_spec, disable_search): + return False + + config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) + config = config_cls(**config) - if search_point.get("num_splits") is None and search_point.get("cost_model") is None: + if config.num_splits is None and config.cost_model is None: logger.info("One of num_splits or cost_model is required.") return False - if search_point.get("cost_model") is not None and accelerator_spec.memory is None: + if config.cost_model is not None and accelerator_spec.memory is None: logger.info("Accelerator memory is required if cost_model is provided.") return False diff --git a/olive/passes/pytorch/torch_trt_conversion.py b/olive/passes/pytorch/torch_trt_conversion.py index 85a698327..53da22111 100644 --- a/olive/passes/pytorch/torch_trt_conversion.py +++ b/olive/passes/pytorch/torch_trt_conversion.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import logging from pathlib import Path -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import torch @@ -67,12 +67,19 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon ), } - def validate_search_point( - 
self, search_point: Dict[str, Any], accelerator_spec: AcceleratorSpec, with_fixed_value: bool = False + @classmethod + def validate_config( + cls, + config: Dict[str, Any], + accelerator_spec: AcceleratorSpec, + disable_search: Optional[bool] = False, ) -> bool: + if not super().validate_config(config, accelerator_spec, disable_search): + return False + # since the run will leverage the host device to move the model to device, # we need to check if the host device is GPU - if self.host_device != Device.GPU: + if accelerator_spec.accelerator_type != Device.GPU: logger.info("TorchTRTConversion only supports GPU.") return False return True diff --git a/olive/systems/azureml/aml_system.py b/olive/systems/azureml/aml_system.py index 2b53a2239..a286dbbf2 100644 --- a/olive/systems/azureml/aml_system.py +++ b/olive/systems/azureml/aml_system.py @@ -248,16 +248,12 @@ def run_pass( the_pass: "Pass", model_config: ModelConfig, output_model_path: str, - point: Optional[Dict[str, Any]] = None, ) -> ModelConfig: - """Run the pass on the model at a specific point in the search space.""" + """Run the pass on the model.""" ml_client = self.azureml_client_config.create_client() # serialize pass - point = point or {} - config = the_pass.config_at_search_point(point) pass_config = the_pass.to_json(check_object=True) - pass_config["config"].update(the_pass.serialize_config(config, check_object=True)) with tempfile.TemporaryDirectory() as tempdir: pipeline_job = self._create_pipeline_for_pass(tempdir, model_config.to_json(check_object=True), pass_config) diff --git a/olive/systems/docker/docker_system.py b/olive/systems/docker/docker_system.py index efafb5961..29560fe41 100644 --- a/olive/systems/docker/docker_system.py +++ b/olive/systems/docker/docker_system.py @@ -108,11 +108,10 @@ def run_pass( the_pass: "Pass", model_config: "ModelConfig", output_model_path: str, - point: Optional[Dict[str, Any]] = None, ) -> "ModelConfig": - """Run the pass on the model at a specific point in the search space.""" + """Run the pass on the model.""" with tempfile.TemporaryDirectory() as tempdir: - return self._run_pass_container(Path(tempdir), the_pass, model_config, output_model_path, point) + return self._run_pass_container(Path(tempdir), the_pass, model_config, output_model_path) def _run_pass_container( self, @@ -120,13 +119,8 @@ def _run_pass_container( the_pass: "Pass", model_config: "ModelConfig", output_model_path: str, - point: Optional[Dict[str, Any]] = None, ) -> "ModelConfig": - point = point or {} - config = the_pass.config_at_search_point(point) - pass_config = the_pass.to_json(check_object=True) - pass_config["config"].update(the_pass.serialize_config(config, check_object=True)) volumes_list = [] runner_output_path = "runner_output" diff --git a/olive/systems/isolated_ort/isolated_ort_system.py b/olive/systems/isolated_ort/isolated_ort_system.py index 4b07729cc..a16d5952f 100644 --- a/olive/systems/isolated_ort/isolated_ort_system.py +++ b/olive/systems/isolated_ort/isolated_ort_system.py @@ -8,7 +8,7 @@ import shutil import tempfile from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union import numpy as np @@ -66,9 +66,8 @@ def run_pass( the_pass: "Pass", model_config: "ModelConfig", output_model_path: str, - point: Optional[Dict[str, Any]] = None, ) -> "ModelConfig": - """Run the pass on the model at a specific point in the search space.""" + """Run the pass on the model.""" 
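
With configurations resolved before instantiation, every system's `run_pass` sheds its `point` parameter; a system simply materializes the model and runs the already-configured pass. A minimal sketch of the new contract in the simplest (local) case, mirroring the `LocalSystem` change shown below:

```python
def run_pass(self, the_pass: "Pass", model_config: ModelConfig, output_model_path: str) -> ModelConfig:
    """Run the pass on the model."""
    model = model_config.create_model()
    output_model = the_pass.run(model, output_model_path)  # no search-point argument anymore
    return ModelConfig.from_json(output_model.to_json())
```
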
logger.warning("IsolatedORTSystem does not support running passes.") raise NotImplementedError diff --git a/olive/systems/local.py b/olive/systems/local.py index 8e8fadc8f..355dd747b 100644 --- a/olive/systems/local.py +++ b/olive/systems/local.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, List from olive.hardware.accelerator import AcceleratorSpec, Device from olive.model import ModelConfig @@ -23,11 +23,10 @@ def run_pass( the_pass: "Pass", model_config: ModelConfig, output_model_path: str, - point: Optional[Dict[str, Any]] = None, ) -> ModelConfig: - """Run the pass on the model at a specific point in the search space.""" + """Run the pass on the model.""" model = model_config.create_model() - output_model = the_pass.run(model, output_model_path, point) + output_model = the_pass.run(model, output_model_path) return ModelConfig.from_json(output_model.to_json()) def evaluate_model( diff --git a/olive/systems/olive_system.py b/olive/systems/olive_system.py index b9f28da89..7a39a1b90 100644 --- a/olive/systems/olive_system.py +++ b/olive/systems/olive_system.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union from olive.common.config_utils import validate_config from olive.systems.common import AcceleratorConfig, SystemType @@ -46,7 +46,6 @@ def run_pass( the_pass: "Pass", model_config: "ModelConfig", output_model_path: str, - point: Optional[Dict[str, Any]] = None, ) -> "ModelConfig": """Run the pass on the model at a specific point in the search space.""" raise NotImplementedError diff --git a/olive/systems/python_environment/python_environment_system.py b/olive/systems/python_environment/python_environment_system.py index bb4400262..52f2528f1 100644 --- a/olive/systems/python_environment/python_environment_system.py +++ b/olive/systems/python_environment/python_environment_system.py @@ -9,7 +9,7 @@ import shutil import tempfile from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union from olive.common.constants import OS from olive.common.utils import run_subprocess @@ -110,13 +110,9 @@ def run_pass( the_pass: "Pass", model_config: ModelConfig, output_model_path: str, - point: Optional[Dict[str, Any]] = None, ) -> ModelConfig: - """Run the pass on the model at a specific point in the search space.""" - point = point or {} - config = the_pass.config_at_search_point(point) + """Run the pass on the model.""" pass_config = the_pass.to_json(check_object=True) - pass_config["config"].update(the_pass.serialize_config(config, check_object=True)) config_jsons = { "model_config": model_config.to_json(check_object=True), "pass_config": pass_config, diff --git a/olive/workflows/run/config.py b/olive/workflows/run/config.py index ebb7919c9..5265e58e0 100644 --- a/olive/workflows/run/config.py +++ b/olive/workflows/run/config.py @@ -270,6 +270,10 @@ def validate_pass_search(cls, v, values): ) return v + @validator("pass_flows", pre=True) + def validate_pass_flows(cls, v, values): + return v or [] + @validator("workflow_host", pre=True) def 
validate_workflow_host(cls, v, values): if v is None: diff --git a/olive/workflows/run/run.py b/olive/workflows/run/run.py index 4de79240c..67c0b691b 100644 --- a/olive/workflows/run/run.py +++ b/olive/workflows/run/run.py @@ -10,7 +10,6 @@ from pathlib import Path from typing import Generator, List, Optional, Union -from olive.auto_optimizer import AutoOptimizer from olive.common.utils import set_tempdir from olive.logging import set_default_logger_severity, set_ort_logger_severity, set_verbosity_info from olive.package_config import OlivePackageConfig @@ -147,9 +146,6 @@ def run_engine(package_config: OlivePackageConfig, run_config: RunConfig): except Exception: logger.warning("ORT log severity level configuration ignored since the module isn't installed.") - # input model - input_model = run_config.input_model - # Azure ML Client olive_config = run_config.to_json() engine = run_config.engine.create_engine(package_config, run_config.azureml_client, workflow_id) @@ -168,10 +164,7 @@ def run_engine(package_config: OlivePackageConfig, run_config: RunConfig): and run_config.auto_optimizer_config is not None and not run_config.auto_optimizer_config.disable_auto_optimizer ) - if auto_optimizer_enabled: - is_ep_required = True - else: - is_ep_required = is_execution_provider_required(run_config, package_config) + is_ep_required = auto_optimizer_enabled or is_execution_provider_required(run_config, package_config) # Register passes since we need to know whether they need to run on target used_passes = list(get_used_passes(run_config)) @@ -196,66 +189,32 @@ def run_engine(package_config: OlivePackageConfig, run_config: RunConfig): engine.target_config, skip_supported_eps_check=target_not_used, is_ep_required=is_ep_required ) - pass_list = [] - acc_list = [] - if auto_optimizer_enabled: - # For auto optimizer, Olive generates passes and pass_flows for each accelerator - # that means, the passes and pass_flows might be different for each accelerator - for acc_spec in accelerator_specs: - _passes, pass_flows = AutoOptimizer( - input_model, - engine.evaluator_config, - acc_spec, - run_config.auto_optimizer_config, - run_config.data_configs, - ).suggest() - pass_list.append(({k: RunPassConfig.parse_obj(v) for k, v in _passes.items()}, pass_flows)) - acc_list.append([acc_spec]) - else: - # For non-auto-optimizer case, Olive uses the same passes and pass_flows for all accelerators - # if user needs different passes and pass_flows for each accelerator, they need to write multiple - # config files. - pass_list.append((run_config.passes, run_config.pass_flows)) - acc_list.append(accelerator_specs) - - run_rls = {} - # Note that, in Olive, there are two positions where the accelerator_specs are looped over: - # 1. olive workflow run level: this is where the accelerator_specs are created and passed to - # the engine. In this level, accelerator specs can be used to generate passes and pass_flows. - # 2. engine level: this is where the accelerator_specs are looped over to run the passes. - # TODO(anyone): refactor the code to remove the engine level loop if possible. - # For time being, we are keeping both loops, but in future, we might want to refactor the code - # to remove engine level loop and pass the accelerator_specs to the engine directly. 
-    for accelerator_spec, (passes, pass_flows) in zip(acc_list, pass_list):
-        engine.reset_passes()
-        pass_flows_to_run = {p for ps in pass_flows for p in ps} if pass_flows else set(passes.keys())
-
-        # Initializes the pass and register it with the engine
-        for pass_name, pass_config in passes.items():
-            if pass_name in pass_flows_to_run:
-                host = pass_config.host.create_system() if pass_config.host is not None else None
-                engine.register(
-                    pass_config.type,
-                    config=pass_config.config,
-                    name=pass_name,
-                    host=host,
-                    evaluator_config=pass_config.evaluator,
-                )
-        engine.set_pass_flows(pass_flows)
-
-        # run
-        run_rls.update(
-            engine.run(
-                input_model,
-                accelerator_spec,
-                run_config.engine.packaging_config,
-                run_config.engine.output_dir,
-                run_config.engine.evaluate_input_model,
-                run_config.engine.log_to_file,
-                run_config.engine.log_severity_level,
-            )
+    # Initialize the passes and register them with the engine
+    passes_to_run = (
+        {pass_name for pass_flow in run_config.pass_flows for pass_name in pass_flow}
+        if run_config.pass_flows
+        else set(run_config.passes.keys())
+    )
+    for pass_name in passes_to_run:
+        pass_run_config = run_config.passes[pass_name]
+        engine.register(
+            pass_run_config.type,
+            config=pass_run_config.config,
+            name=pass_name,
+            host=pass_run_config.host.create_system() if pass_run_config.host is not None else None,
+            evaluator_config=pass_run_config.evaluator,
         )
-    return run_rls
+    engine.set_pass_flows(run_config.pass_flows or [list(run_config.passes.keys())])
+
+    return engine.run(
+        run_config.input_model,
+        accelerator_specs,
+        run_config.engine.packaging_config,
+        run_config.engine.output_dir,
+        run_config.engine.evaluate_input_model,
+        run_config.engine.log_to_file,
+        run_config.engine.log_severity_level,
+    )
 
 
 def set_olive_config_for_aml_system(olive_config: dict):
diff --git a/test/unit_test/engine/test_engine.py b/test/unit_test/engine/test_engine.py
index b228a913f..0e4e82cff 100644
--- a/test/unit_test/engine/test_engine.py
+++ b/test/unit_test/engine/test_engine.py
@@ -62,10 +62,10 @@ def test_register(self, tmpdir):
         engine.register(OnnxConversion, host=system, evaluator_config=evaluator_config)
 
         # assert
-        assert name in engine.pass_config
-        assert engine.pass_config[name]["type"] == OnnxConversion
-        assert engine.pass_config[name]["host"] == system
-        assert engine.pass_config[name]["evaluator"] == evaluator_config
+        assert name in engine.pass_run_configs
+        assert engine.pass_run_configs[name]["type"] == OnnxConversion.__name__
+        assert engine.pass_run_configs[name]["host"] == system
+        assert engine.pass_run_configs[name]["evaluator"] == evaluator_config
 
     def test_register_no_search(self, tmpdir):
         # setup
@@ -82,7 +82,7 @@ def test_register_no_search(self, tmpdir):
         engine.register(OnnxDynamicQuantization)
 
         # assert
-        assert "OnnxDynamicQuantization" in engine.pass_config
+        assert "OnnxDynamicQuantization" in engine.pass_run_configs
 
     def test_default_engine_run(self, tmpdir):
         # setup
@@ -169,11 +169,7 @@ def test_run(self, mock_local_system, tmp_path):
 
         # execute
         output_dir = Path(tmp_path)
-        actual_res = engine.run(
-            model_config,
-            [DEFAULT_CPU_ACCELERATOR],
-            output_dir=output_dir,
-        )
+        actual_res = engine.run(model_config, [DEFAULT_CPU_ACCELERATOR], output_dir=output_dir)
         actual_res = actual_res[DEFAULT_CPU_ACCELERATOR]
 
         # make sure the input model always be in engine.footprints
diff --git a/test/unit_test/passes/common/test_user_script.py b/test/unit_test/passes/common/test_user_script.py
index 9e0fd2d38..d21fe9934 100644
--- 
a/test/unit_test/passes/common/test_user_script.py +++ b/test/unit_test/passes/common/test_user_script.py @@ -9,5 +9,5 @@ class TestUserScriptConfig: def test_no_config(self): config = {} - config = OrtSessionParamsTuning.generate_search_space(DEFAULT_CPU_ACCELERATOR, config, True) + config = OrtSessionParamsTuning.generate_config(DEFAULT_CPU_ACCELERATOR, config, True) assert config diff --git a/test/unit_test/passes/onnx/test_bnb_quantization.py b/test/unit_test/passes/onnx/test_bnb_quantization.py index 2edb24a0b..979e75f2d 100644 --- a/test/unit_test/passes/onnx/test_bnb_quantization.py +++ b/test/unit_test/passes/onnx/test_bnb_quantization.py @@ -78,9 +78,9 @@ def test_validate_quant_type(pass_config, model_attributes, expected_error, tmp_ p = create_pass_from_dict(OnnxBnb4Quantization, pass_config, disable_search=True) if expected_error: with pytest.raises(expected_error): - p.run(input_model, str(tmp_path / "model.onnx"), None) + p.run(input_model, str(tmp_path / "model.onnx")) else: - p.run(input_model, str(tmp_path / "model.onnx"), None) + p.run(input_model, str(tmp_path / "model.onnx")) @pytest.mark.parametrize(("model_generator", "expected_count"), [(get_onnx_matmul_model, 1), (get_onnx_gemm_model, 0)]) @@ -117,5 +117,5 @@ def count_matmulbnb4_nodes(model: onnx.ModelProto): def test_quantized_modules(tmp_path, model_generator, quantized_modules, expected_count): input_model = model_generator(str(tmp_path / "model.onnx")) p = create_pass_from_dict(OnnxBnb4Quantization, {"quant_type": "nf4", "quantized_modules": quantized_modules}) - output_model = p.run(input_model, (tmp_path / "output_model.onnx"), None) + output_model = p.run(input_model, (tmp_path / "output_model.onnx")) assert count_matmulbnb4_nodes(output_model.load_model()) == expected_count diff --git a/test/unit_test/passes/onnx/test_optimum_conversion.py b/test/unit_test/passes/onnx/test_optimum_conversion.py index 07dcb2dc7..aecda404b 100644 --- a/test/unit_test/passes/onnx/test_optimum_conversion.py +++ b/test/unit_test/passes/onnx/test_optimum_conversion.py @@ -86,13 +86,13 @@ def test_optimum_configs(config, is_valid, tmp_path): output_folder = tmp_path if not is_valid: - assert p.validate_search_point(config, None) is False + assert p.validate_config(config, None) is False with pytest.raises( ValueError, match="FP16 export is supported only when exporting on GPU. 
Please pass the option `--device cuda`.", ): p.run(input_model, output_folder) else: - assert p.validate_search_point(config, None) + assert p.validate_config(config, None) onnx_model = p.run(input_model, output_folder) assert Path(onnx_model.model_path).exists() diff --git a/test/unit_test/passes/onnx/test_transformer_optimization.py b/test/unit_test/passes/onnx/test_transformer_optimization.py index 5f1da244d..52a7c6819 100644 --- a/test/unit_test/passes/onnx/test_transformer_optimization.py +++ b/test/unit_test/passes/onnx/test_transformer_optimization.py @@ -21,7 +21,7 @@ def test_fusion_options(): config = {"model_type": "bart", "optimization_options": {"use_multi_head_attention": True}} - config = OrtTransformersOptimization.generate_search_space(DEFAULT_CPU_ACCELERATOR, config, True) + config = OrtTransformersOptimization.generate_config(DEFAULT_CPU_ACCELERATOR, config, disable_search=True) transformer_optimization = OrtTransformersOptimization(DEFAULT_CPU_ACCELERATOR, config, True) run_config = deepcopy(config) del ( @@ -47,7 +47,7 @@ def test_ort_transformer_optimization_pass(tmp_path): input_model = get_onnx_model() config = {"model_type": "bert"} - config = OrtTransformersOptimization.generate_search_space(DEFAULT_CPU_ACCELERATOR, config, disable_search=True) + config = OrtTransformersOptimization.generate_config(DEFAULT_CPU_ACCELERATOR, config, disable_search=True) p = OrtTransformersOptimization(DEFAULT_CPU_ACCELERATOR, config, True) output_folder = str(tmp_path / "onnx") @@ -74,9 +74,9 @@ def test_invalid_ep_config(use_gpu, fp16, accelerator_spec, mock_inferece_sessio input_model = get_onnx_model() config = {"model_type": "bert", "use_gpu": use_gpu, "float16": fp16} - config = OrtTransformersOptimization.generate_search_space(accelerator_spec, config, disable_search=True) + config = OrtTransformersOptimization.generate_config(accelerator_spec, config, disable_search=True) p = OrtTransformersOptimization(accelerator_spec, config, True) - is_pruned = not p.validate_search_point(config, accelerator_spec) + is_pruned = not p.validate_config(config, accelerator_spec, disable_search=True) if accelerator_spec.execution_provider == "CPUExecutionProvider": if fp16 and use_gpu: @@ -143,7 +143,7 @@ def test_transformer_optimization_invalid_model_type(tmp_path): input_model = get_onnx_model() config = {"model_type": None} - config = OrtTransformersOptimization.generate_search_space(DEFAULT_CPU_ACCELERATOR, config, disable_search=True) + config = OrtTransformersOptimization.generate_config(DEFAULT_CPU_ACCELERATOR, config, disable_search=True) p = OrtTransformersOptimization(DEFAULT_CPU_ACCELERATOR, config, True) output_folder = str(tmp_path / "onnx") @@ -161,7 +161,7 @@ def test_optimization_with_provider(mock_proto_to_model, mock_optimize_model, tm config = {"model_type": "bert", "use_gpu": True} dml_ep = AcceleratorSpec(accelerator_type=Device.GPU, execution_provider="DmlExecutionProvider") - config = OrtTransformersOptimization.generate_search_space(dml_ep, config, disable_search=True) + config = OrtTransformersOptimization.generate_config(dml_ep, config, disable_search=True) p = OrtTransformersOptimization(dml_ep, config, True) output_folder = str(tmp_path / "onnx") diff --git a/test/unit_test/passes/test_pass_serialization.py b/test/unit_test/passes/test_pass_serialization.py index 21214734e..88e09906e 100644 --- a/test/unit_test/passes/test_pass_serialization.py +++ b/test/unit_test/passes/test_pass_serialization.py @@ -11,8 +11,8 @@ @pytest.mark.parametrize("host_device", [None, 
"cpu", "gpu"]) def test_pass_serialization(host_device): - onnx_conversion_config = {} - config = OnnxConversion.generate_search_space(DEFAULT_CPU_ACCELERATOR, onnx_conversion_config) + config = {} + config = OnnxConversion.generate_config(DEFAULT_CPU_ACCELERATOR, config) onnx_conversion = OnnxConversion(DEFAULT_CPU_ACCELERATOR, config, host_device=host_device) json = onnx_conversion.to_json(True) diff --git a/test/unit_test/systems/docker/test_docker_system.py b/test/unit_test/systems/docker/test_docker_system.py index 7c8613ae3..b1eaca8dd 100644 --- a/test/unit_test/systems/docker/test_docker_system.py +++ b/test/unit_test/systems/docker/test_docker_system.py @@ -277,8 +277,6 @@ def test_runner_entry(self, tmp_path): p = create_pass_from_dict(OrtSessionParamsTuning, {}, disable_search=True) pass_config = p.to_json(check_object=True) - config = p.config_at_search_point({}) - pass_config["config"].update(p.serialize_config(config, check_object=True)) onnx_model = get_onnx_model_config() diff --git a/test/unit_test/systems/python_environment/test_python_environment_system.py b/test/unit_test/systems/python_environment/test_python_environment_system.py index 5d92c5299..04efb6233 100644 --- a/test/unit_test/systems/python_environment/test_python_environment_system.py +++ b/test/unit_test/systems/python_environment/test_python_environment_system.py @@ -151,13 +151,10 @@ def test_run_pass(self, mock_model_config_parse_obj, mock__run_command): "dummy_param_2": "dummy_param_2_value", }, } - full_config = { - "dummy_param_1": "dummy_param_1_value", - "dummy_param_2": "dummy_param_2_value2", - } - expected_pass_config = {"type": "DummyPass", "config": full_config} + dummy_config = dummy_pass_config["config"] + expected_pass_config = {"type": "DummyPass", "config": dummy_config} the_pass.to_json.return_value = dummy_pass_config - the_pass.serialize_config.return_value = full_config + the_pass.serialize_config.return_value = dummy_config # mock return value mock_return_value = {"dummy_output_model_key": "dummy_output_model_value"} @@ -233,9 +230,8 @@ def test_pass_runner(self, mock_conversion_run, tmp_path): # create pass_config.json the_pass = get_onnxconversion_pass() - config = the_pass.config_at_search_point({}) pass_config = the_pass.to_json(check_object=True) - pass_config["config"].update(the_pass.serialize_config(config, check_object=True)) + with (tmp_path / "pass_config.json").open("w") as f: json.dump(pass_config, f) diff --git a/test/unit_test/systems/test_local.py b/test/unit_test/systems/test_local.py index 91eedbc77..cde0f2918 100644 --- a/test/unit_test/systems/test_local.py +++ b/test/unit_test/systems/test_local.py @@ -36,7 +36,7 @@ def test_run_pass(self): self.system.run_pass(p, olive_model, output_model_path) # assert - p.run.assert_called_once_with(olive_model.create_model(), output_model_path, None) + p.run.assert_called_once_with(olive_model.create_model(), output_model_path) METRIC_TEST_CASE: ClassVar[List[Metric]] = [ (partial(get_accuracy_metric, AccuracySubType.ACCURACY_SCORE)), diff --git a/test/unit_test/utils.py b/test/unit_test/utils.py index 551a26144..f3f670c08 100644 --- a/test/unit_test/utils.py +++ b/test/unit_test/utils.py @@ -240,11 +240,7 @@ def get_onnxconversion_pass(ignore_pass_config=True, target_opset=13): onnx_conversion_config = {"target_opset": target_opset} p = create_pass_from_dict(OnnxConversion, onnx_conversion_config) - if ignore_pass_config: - return p - pass_config = p.config_at_search_point({}) - pass_config = p.serialize_config(pass_config) - 
return p, pass_config + return p if ignore_pass_config else (p, p.to_json(check_object=True)["config"]) def get_onnx_dynamic_quantization_pass(disable_search=False):
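
For downstream callers, the migration this patch implies looks roughly like the sketch below. It is illustrative only, not part of the patch; `accel`, `raw_config`, `model`, and `out_path` are hypothetical placeholders, while the method names are taken from the diffs above.

    # Before: config was resolved from a search point at run time.
    config = OrtTransformersOptimization.generate_search_space(accel, raw_config, True)
    p = OrtTransformersOptimization(accel, config, True)
    assert p.validate_search_point(config, accel)
    output_model = p.run(model, out_path, None)  # trailing arg was the search point

    # After: the config is fully validated before the pass is instantiated.
    config = OrtTransformersOptimization.generate_config(accel, raw_config, disable_search=True)
    p = OrtTransformersOptimization(accel, config, True)
    assert p.validate_config(config, accel, disable_search=True)
    output_model = p.run(model, out_path)  # no search point; config is already concrete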