Skip to content

Commit

Permalink
Add Translation Structure
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Mar 18, 2024
1 parent 749e27c commit d442b66
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 0 deletions.
173 changes: 173 additions & 0 deletions src/sparseml/transformers/utils/export_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
General utilities for exporting models to different formats using safe tensors.
"""

import logging
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union

from transformers import PreTrainedModel, PreTrainedTokenizerBase

from sparseml.utils import get_unique_dir_name


# Names exported by ``from <this module> import *``.
__all__ = [
"export_safetensors",
"SUPPORTED_FORMAT_TYPES",
]

# Closed set of target-format identifiers accepted by the export and
# translation helpers in this module.
SUPPORTED_FORMAT_TYPES = Literal["exllama", "marlin"]
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)


def export_safetensors(
    model: PreTrainedModel,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    format: SUPPORTED_FORMAT_TYPES = "exllama",
    save_dir: Union[str, Path, None] = None,
):
    """
    A utility function to export a model to safetensor format.

    Calls the appropriate state dict and config translation functions based
    on the format and saves the translated state dict, the tokenizer and the
    translated config to the target directory.

    If ``save_dir`` is not given it defaults to ``cwd/exported_model``. If the
    target directory already exists, a unique directory name (derived from the
    requested one) is used instead, so existing exports are not overwritten.

    :param model: The loaded model to be exported.
    :param tokenizer: The tokenizer associated with the model. Passed through
        to the tokenizer saving helper as-is (may be None).
    :param format: The format to which the model should be exported.
        Default is "exllama".
    :param save_dir: The directory where the model should be saved.
        Defaults to ``cwd/exported_model`` when None.
    """
    # original state dict and config
    state_dict: Dict[Any, Any] = model.state_dict()
    config: Any = model.config

    # translation
    _LOGGER.info(f"Translating state dict and config to {format} format.")
    translated_state_dict: Dict[Any, Any] = translate_state_dict(
        state_dict=state_dict, format=format
    )
    translated_config: Any = translate_config(config=config, format=format)
    _LOGGER.info(f"Translation to {format} format complete.")

    if save_dir is None:
        save_dir = Path.cwd() / "exported_model"

    # never write into an existing directory; pick a fresh name instead
    save_dir = get_unique_dir_name(dir_name=save_dir)

    # saving
    _save_state_dict_safetensor(
        state_dict=translated_state_dict, name="model.safetensors", parent_dir=save_dir
    )
    _save_tokenizer(tokenizer=tokenizer, parent_dir=save_dir)
    # _save_config declares its parameter as ``config``; passing ``model=``
    # would fail with a TypeError before anything is saved
    _save_config(config=translated_config, parent_dir=save_dir)

    _LOGGER.info(
        f"Model and it's artifacts exported to {format} format and saved to {save_dir}."
    )


def translate_state_dict(
    state_dict: Dict[Any, Any], format: SUPPORTED_FORMAT_TYPES
) -> Dict[Any, Any]:
    """
    A utility function to translate the state dict to the specified format.

    :param state_dict: The state dict to be translated.
    :param format: The format to which the state dict should be translated.
        One of "exllama" or "marlin".
    :return: The translated state dict, as produced by the format-specific
        translation helper.
    :raises ValueError: If ``format`` is not a supported format.
    """
    if format == "marlin":
        return _translate_state_dict_marlin(state_dict=state_dict)
    if format == "exllama":
        return _translate_state_dict_exllama(state_dict=state_dict)

    # Fail loudly instead of implicitly returning None, which would only
    # surface later as a confusing error while saving.
    raise ValueError(
        f"Unsupported format {format!r}; expected one of 'exllama', 'marlin'."
    )


def translate_config(config: Any, format: SUPPORTED_FORMAT_TYPES) -> Any:
    """
    A utility function to translate the config to the specified format.

    :param config: The config to be translated.
    :param format: The format to which the config should be translated.
        One of "exllama" or "marlin".
    :return: The translated config, as produced by the format-specific
        translation helper.
    :raises ValueError: If ``format`` is not a supported format.
    """
    if format == "marlin":
        return _translate_config_marlin(config=config)
    if format == "exllama":
        return _translate_config_exllama(config=config)

    # Fail loudly instead of implicitly returning None, which would only
    # surface later as a confusing error while saving.
    raise ValueError(
        f"Unsupported format {format!r}; expected one of 'exllama', 'marlin'."
    )


def _translate_state_dict_marlin(state_dict: Dict[Any, Any]) -> Dict[Any, Any]:
    """
    Placeholder for the Marlin state dict translation.

    :param state_dict: The state dict that would be translated.
    :raises NotImplementedError: Always; the translation is not implemented.
    """
    message = "Translating state dict to Marlin format is not yet supported."
    raise NotImplementedError(message)


def _translate_config_marlin(config: Any) -> Any:
    """
    Placeholder for the Marlin config translation.

    :param config: The config that would be translated.
    :raises NotImplementedError: Always; the translation is not implemented.
    """
    message = "Translating config to Marlin format is not yet supported."
    raise NotImplementedError(message)


def _translate_state_dict_exllama(state_dict: Dict[Any, Any]) -> Dict[Any, Any]:
    """
    Placeholder for the Exllama state dict translation.

    :param state_dict: The state dict that would be translated.
    :raises NotImplementedError: Always; the translation is not implemented.
    """
    message = "Translating state dict to Exllama format is not yet supported."
    raise NotImplementedError(message)


def _translate_config_exllama(config: Any) -> Any:
    """
    Placeholder for the Exllama config translation.

    :param config: The config that would be translated.
    :raises NotImplementedError: Always; the translation is not implemented.
    """
    message = "Translating config to Exllama format is not yet supported."
    raise NotImplementedError(message)


def _save_state_dict_safetensor(
    state_dict: Dict[Any, Any], name: str, parent_dir: Union[str, Path]
):
    """
    Placeholder for writing a state dict to a safetensors file.

    :param state_dict: The state dict that would be written.
    :param name: The file name to use inside ``parent_dir``.
    :param parent_dir: The directory the file would be written to.
    :raises NotImplementedError: Always; saving is not implemented.
    """
    message = "Saving state dict to safe tesnor file is not yet supported."
    raise NotImplementedError(message)


def _save_tokenizer(
    tokenizer: Optional[PreTrainedTokenizerBase], parent_dir: Union[str, Path]
):
    """
    Placeholder for writing the tokenizer to disk.

    :param tokenizer: The tokenizer that would be saved (may be None).
    :param parent_dir: The directory the tokenizer would be written to.
    :raises NotImplementedError: Always; saving is not implemented.
    """
    message = "Saving tokenizer to a file is not yet supported."
    raise NotImplementedError(message)


def _save_config(config: Any, parent_dir: Union[str, Path]):
    """
    Placeholder for writing the config to disk.

    :param config: The config that would be saved.
    :param parent_dir: The directory the config would be written to.
    :raises NotImplementedError: Always; saving is not implemented.
    """
    message = "Saving config to a file is not yet supported."
    raise NotImplementedError(message)
21 changes: 21 additions & 0 deletions src/sparseml/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"parse_kwarg_tuples",
"download_zoo_training_dir",
"is_package_available",
"get_unique_dir_name",
]


Expand Down Expand Up @@ -974,3 +975,23 @@ def is_package_available(
return package_exists, package_version
else:
return package_exists


def get_unique_dir_name(dir_name: Union[str, Path]) -> str:
    """
    A utility function to get a unique directory name by appending
    ``_<number>`` to the directory name if the path already exists.

    The numbering starts at 1 and increases until a path that does not
    exist is found. If the requested path does not exist, it is returned
    unchanged (as a string).

    (Note: the function does not create the directory, it only
    returns the unique directory name; another process may still create
    the same path between this check and the caller using the name.)

    :param dir_name: The directory name to get a unique name for
    :return: The unique directory name
    """
    # local import keeps this function self-contained
    import os

    requested: str = str(dir_name)
    if not Path(requested).exists():
        return requested

    # Drop trailing path separators so the suffix is attached to the
    # directory name itself ("out/" -> "out_1") instead of producing a
    # child path ("out/_1"). Fall back to the raw string for paths that
    # consist only of separators (e.g. the filesystem root).
    separators = os.sep + (os.altsep or "")
    base: str = requested.rstrip(separators) or requested

    counter: int = 1
    candidate: str = f"{base}_{counter}"
    while Path(candidate).exists():
        counter += 1
        candidate = f"{base}_{counter}"
    return candidate

0 comments on commit d442b66

Please sign in to comment.