From d442b668195797f6b6f5bd2978f24dd5cd4e1fc0 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Tue, 5 Mar 2024 21:13:04 +0000 Subject: [PATCH] Add Translation Structure --- .../transformers/utils/export_helpers.py | 173 ++++++++++++++++++ src/sparseml/utils/helpers.py | 21 +++ 2 files changed, 194 insertions(+) create mode 100644 src/sparseml/transformers/utils/export_helpers.py diff --git a/src/sparseml/transformers/utils/export_helpers.py b/src/sparseml/transformers/utils/export_helpers.py new file mode 100644 index 00000000000..356fa078432 --- /dev/null +++ b/src/sparseml/transformers/utils/export_helpers.py @@ -0,0 +1,173 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +General utilities for exporting models to different formats using safe tensors. +""" + +import logging +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Union + +from transformers import PreTrainedModel, PreTrainedTokenizerBase + +from sparseml.utils import get_unique_dir_name + + +__all__ = [ + "export_safetensors", + "SUPPORTED_FORMAT_TYPES", +] + +SUPPORTED_FORMAT_TYPES = Literal["exllama", "marlin"] +_LOGGER = logging.getLogger(__name__) + + +def export_safetensors( + model: PreTrainedModel, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + format: SUPPORTED_FORMAT_TYPES = "exllama", + save_dir: Union[str, Path, None] = None, +): + """ + A utility function to export a model to safetensor format. + Calls the appropriate state dict translation function based on the format + and saves the translated state dict to the specified directory. + If the directory is defaults to cwd/exported_model. If the directory already exists, + a new directory is created with a unique name. + + :param model: The loaded model to be exported. + :param tokenizer: The tokenizer associated with the model. + :param format: The format to which the model should be exported. + Default is "exllama". + :param save_dir: The directory where the model should be saved. + """ + # original state dict and config + state_dict: Dict[Any, Any] = model.state_dict() + config: Dict[str, Any] = model.config + + # translation + _LOGGER.info(f"Translating state dict and config to {format} format.") + translated_state_dict: Dict[Any, Any] = translate_state_dict( + state_dict=state_dict, format=format + ) + translated_config: Any = translate_config(config=config, format=format) + _LOGGER.info(f"Translation to {format} format complete.") + + if save_dir is None: + save_dir = Path.cwd() / "exported_model" + + save_dir: str = get_unique_dir_name(dir_name=save_dir) + + # saving + _save_state_dict_safetensor( + state_dict=translated_state_dict, name="model.safetensors", parent_dir=save_dir + ) + _save_tokenizer(tokenizer=tokenizer, parent_dir=save_dir) + _save_config(model=translated_config, parent_dir=save_dir) + + _LOGGER.info( + f"Model and it's artifacts exported to {format} format and saved to {save_dir}." + ) + + +def translate_state_dict( + state_dict: Dict[Any, Any], format: SUPPORTED_FORMAT_TYPES +) -> Dict[Any, Any]: + """ + A utility function to translate the state dict to the specified format. + + :param state_dict: The state dict to be translated. + :param format: The format to which the state dict should be translated. + """ + if format == "marlin": + return _translate_state_dict_marlin(state_dict=state_dict) + if format == "exllama": + return _translate_state_dict_exllama(state_dict=state_dict) + + +def translate_config(config: Any, format: SUPPORTED_FORMAT_TYPES) -> Any: + """ + A utility function to translate the config to the specified format. + + :param config: The config to be translated. + :param format: The format to which the config should be translated. + """ + if format == "marlin": + return _translate_config_marlin(config=config) + if format == "exllama": + return _translate_config_exllama(config=config) + + +def _translate_state_dict_marlin(state_dict: Dict[Any, Any]) -> Dict[Any, Any]: + """ + Translate the state dict to the Marlin format. + """ + raise NotImplementedError( + "Translating state dict to Marlin format is not yet supported." + ) + + +def _translate_config_marlin(config: Any) -> Any: + """ + Translate the config to the Marlin format. + """ + raise NotImplementedError( + "Translating config to Marlin format is not yet supported." + ) + + +def _translate_state_dict_exllama(state_dict: Dict[Any, Any]) -> Dict[Any, Any]: + """ + Translate the state dict to the Exllama format. + """ + raise NotImplementedError( + "Translating state dict to Exllama format is not yet supported." + ) + + +def _translate_config_exllama(config: Any) -> Any: + """ + Translate the config to the Exllama format. + """ + raise NotImplementedError( + "Translating config to Exllama format is not yet supported." + ) + + +def _save_state_dict_safetensor( + state_dict: Dict[Any, Any], name: str, parent_dir: Union[str, Path] +): + """ + Save the state dict to a safe tensor file. + """ + raise NotImplementedError( + "Saving state dict to safe tesnor file is not yet supported." + ) + + +def _save_tokenizer( + tokenizer: Optional[PreTrainedTokenizerBase], parent_dir: Union[str, Path] +): + """ + Save the tokenizer to a file. + """ + raise NotImplementedError("Saving tokenizer to a file is not yet supported.") + + +def _save_config(config: Any, parent_dir: Union[str, Path]): + """ + Save the config to a file. + """ + raise NotImplementedError("Saving config to a file is not yet supported.") diff --git a/src/sparseml/utils/helpers.py b/src/sparseml/utils/helpers.py index 6c1d4f3ad6c..0c5b4d8330e 100644 --- a/src/sparseml/utils/helpers.py +++ b/src/sparseml/utils/helpers.py @@ -74,6 +74,7 @@ "parse_kwarg_tuples", "download_zoo_training_dir", "is_package_available", + "get_unique_dir_name", ] @@ -974,3 +975,23 @@ def is_package_available( return package_exists, package_version else: return package_exists + + +def get_unique_dir_name(dir_name: Union[str, Path]) -> str: + """ + A utility function to get a unique directory name by appending + a number to the directory name if the directory already exists + (Note: the function does not create the directory, it only + returns the unique directory name) + + :param dir_name: The directory name to get a unique name for + :return: The unique directory name + """ + dir_name: str = str(dir_name) + counter: str = 1 + new_dir_name: str = dir_name + + while Path(new_dir_name).exists(): + new_dir_name = f"{dir_name}_{counter}" + counter += 1 + return new_dir_name