-
Notifications
You must be signed in to change notification settings - Fork 149
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
749e27c
commit d442b66
Showing
2 changed files
with
194 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" | ||
General utilities for exporting models to different formats using safe tensors. | ||
""" | ||
|
||
import logging | ||
from pathlib import Path | ||
from typing import Any, Dict, Literal, Optional, Union | ||
|
||
from transformers import PreTrainedModel, PreTrainedTokenizerBase | ||
|
||
from sparseml.utils import get_unique_dir_name | ||
|
||
|
||
# Names exported by ``from <module> import *``.
__all__ = ["export_safetensors", "SUPPORTED_FORMAT_TYPES"]

# Export formats accepted by the functions in this module.
SUPPORTED_FORMAT_TYPES = Literal["exllama", "marlin"]

# Module-level logger.
_LOGGER = logging.getLogger(__name__)
|
||
|
||
def export_safetensors(
    model: PreTrainedModel,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    format: SUPPORTED_FORMAT_TYPES = "exllama",
    save_dir: Union[str, Path, None] = None,
):
    """
    A utility function to export a model to safetensor format.

    Calls the appropriate state dict and config translation functions based
    on the format, then saves the translated state dict, the tokenizer and
    the translated config to the target directory.

    If no directory is given, it defaults to cwd/exported_model. If the
    directory already exists, a new directory is created with a unique name.

    :param model: The loaded model to be exported.
    :param tokenizer: The tokenizer associated with the model.
    :param format: The format to which the model should be exported.
        Default is "exllama".
    :param save_dir: The directory where the model should be saved.
        Defaults to cwd/exported_model when None.
    """
    # original state dict and config
    state_dict: Dict[Any, Any] = model.state_dict()
    config: Any = model.config

    # translation
    _LOGGER.info(f"Translating state dict and config to {format} format.")
    translated_state_dict: Dict[Any, Any] = translate_state_dict(
        state_dict=state_dict, format=format
    )
    translated_config: Any = translate_config(config=config, format=format)
    _LOGGER.info(f"Translation to {format} format complete.")

    if save_dir is None:
        save_dir = Path.cwd() / "exported_model"

    # resolve the final (unique) output directory before writing anything
    save_dir = get_unique_dir_name(dir_name=save_dir)

    # saving
    _save_state_dict_safetensor(
        state_dict=translated_state_dict, name="model.safetensors", parent_dir=save_dir
    )
    _save_tokenizer(tokenizer=tokenizer, parent_dir=save_dir)
    # _save_config takes the translated config via its ``config`` parameter;
    # passing it as ``model=`` would raise a TypeError.
    _save_config(config=translated_config, parent_dir=save_dir)

    _LOGGER.info(
        f"Model and its artifacts exported to {format} format and saved to {save_dir}."
    )
|
||
|
||
def translate_state_dict(
    state_dict: Dict[Any, Any], format: SUPPORTED_FORMAT_TYPES
) -> Dict[Any, Any]:
    """
    A utility function to translate the state dict to the specified format.

    :param state_dict: The state dict to be translated.
    :param format: The format to which the state dict should be translated.
    :return: The translated state dict.
    :raises ValueError: If the format is neither "marlin" nor "exllama".
    """
    if format == "marlin":
        return _translate_state_dict_marlin(state_dict=state_dict)
    if format == "exllama":
        return _translate_state_dict_exllama(state_dict=state_dict)
    # Fail loudly instead of implicitly returning None for an unknown format.
    raise ValueError(
        f"Unsupported format {format!r}; expected one of 'exllama' or 'marlin'."
    )
|
||
|
||
def translate_config(config: Any, format: SUPPORTED_FORMAT_TYPES) -> Any:
    """
    A utility function to translate the config to the specified format.

    :param config: The config to be translated.
    :param format: The format to which the config should be translated.
    :return: The translated config.
    :raises ValueError: If the format is neither "marlin" nor "exllama".
    """
    if format == "marlin":
        return _translate_config_marlin(config=config)
    if format == "exllama":
        return _translate_config_exllama(config=config)
    # Fail loudly instead of implicitly returning None for an unknown format.
    raise ValueError(
        f"Unsupported format {format!r}; expected one of 'exllama' or 'marlin'."
    )
|
||
|
||
def _translate_state_dict_marlin(state_dict: Dict[Any, Any]) -> Dict[Any, Any]:
    """
    Placeholder for translating a state dict to the Marlin format.

    :param state_dict: The state dict to be translated.
    :raises NotImplementedError: Always; the translation is not implemented.
    """
    message = "Translating state dict to Marlin format is not yet supported."
    raise NotImplementedError(message)
|
||
|
||
def _translate_config_marlin(config: Any) -> Any:
    """
    Placeholder for translating a config to the Marlin format.

    :param config: The config to be translated.
    :raises NotImplementedError: Always; the translation is not implemented.
    """
    message = "Translating config to Marlin format is not yet supported."
    raise NotImplementedError(message)
|
||
|
||
def _translate_state_dict_exllama(state_dict: Dict[Any, Any]) -> Dict[Any, Any]:
    """
    Placeholder for translating a state dict to the Exllama format.

    :param state_dict: The state dict to be translated.
    :raises NotImplementedError: Always; the translation is not implemented.
    """
    message = "Translating state dict to Exllama format is not yet supported."
    raise NotImplementedError(message)
|
||
|
||
def _translate_config_exllama(config: Any) -> Any:
    """
    Placeholder for translating a config to the Exllama format.

    :param config: The config to be translated.
    :raises NotImplementedError: Always; the translation is not implemented.
    """
    message = "Translating config to Exllama format is not yet supported."
    raise NotImplementedError(message)
|
||
|
||
def _save_state_dict_safetensor(
    state_dict: Dict[Any, Any], name: str, parent_dir: Union[str, Path]
):
    """
    Save the state dict to a safe tensor file.

    :param state_dict: The state dict to be saved.
    :param name: The name of the file to save the state dict to.
    :param parent_dir: The directory in which the file should be saved.
    :raises NotImplementedError: Always; saving is not implemented yet.
    """
    raise NotImplementedError(
        "Saving state dict to safe tensor file is not yet supported."
    )
|
||
|
||
def _save_tokenizer(
    tokenizer: Optional[PreTrainedTokenizerBase], parent_dir: Union[str, Path]
):
    """
    Placeholder for saving the tokenizer to a file.

    :param tokenizer: The tokenizer to be saved, may be None.
    :param parent_dir: The directory in which the tokenizer should be saved.
    :raises NotImplementedError: Always; saving is not implemented yet.
    """
    message = "Saving tokenizer to a file is not yet supported."
    raise NotImplementedError(message)
|
||
|
||
def _save_config(config: Any, parent_dir: Union[str, Path]):
    """
    Placeholder for saving the config to a file.

    :param config: The config to be saved.
    :param parent_dir: The directory in which the config should be saved.
    :raises NotImplementedError: Always; saving is not implemented yet.
    """
    message = "Saving config to a file is not yet supported."
    raise NotImplementedError(message)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters