Refactor Initialized Model Export #2224

Merged · 7 commits · Apr 10, 2024
26 changes: 20 additions & 6 deletions src/sparseml/export/export.py
@@ -90,6 +90,7 @@ def export(
source_path: Union[Path, str] = None,
target_path: Union[Path, str, None] = None,
model: Optional["torch.nn.Module"] = None, # noqa F401
tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa F401
onnx_model_name: str = ONNX_MODEL_NAME,
deployment_target: str = "deepsparse",
opset: Optional[int] = None,
@@ -134,11 +135,9 @@ def export(
will default to source_path
:param model: The PyTorch model to export. If provided, the source_path
should be set to None to avoid potential confusion and entanglement
of sources. This means that, the full
export logic will not be enforced (e.g. the final deployment directory
will not be complete, it will not be possible to run validate_structure
method or apply some optimizations that require complete deployment
directory structure)
of sources
:param tokenizer: An optional tokenizer to export when the model is passed in
directly through the model argument. This argument has no effect if a source_path is provided
:param onnx_model_name: The name of the exported model.
Defaults to ONNX_MODEL_NAME.
:param deployment_target: The deployment target to export
@@ -184,6 +183,7 @@ def export(
from sparseml.export.validators import validate_structure as validate_structure_
from sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
remove_past_key_value_support_from_config,
resolve_integration,
)
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
@@ -206,8 +206,18 @@ def export(
source_path = process_source_path(source_path)
if target_path is None:
target_path = source_path
if tokenizer is not None:
_LOGGER.warning(
    "Passing a tokenizer is not supported when exporting from "
    "a source path. The tokenizer will be ignored."
)

if model is not None and hasattr(model, "config"):
model.config = remove_past_key_value_support_from_config(model.config)

integration = resolve_integration(source_path, integration)
integration = resolve_integration(
source_path=source_path, source_model=model, integration=integration
)
_LOGGER.info(f"Starting export for {integration} model...")

if target_path is None:
@@ -262,6 +272,8 @@ def export(
session_manager.active_session().reset()

_LOGGER.info("Creating data loader for the export...")
if tokenizer is not None:
loaded_model_kwargs["tokenizer"] = tokenizer
data_loader, loaded_data_loader_kwargs = helper_functions.create_data_loader(
model=model,
task=task,
@@ -323,6 +335,8 @@ def export(

deployment_folder_dir = create_deployment_folder(
source_path=source_path,
source_config=getattr(model, "config", None),
source_tokenizer=tokenizer,
target_path=target_path,
deployment_directory_name=deployment_directory_name,
deployment_directory_files_mandatory=helper_functions.deployment_directory_files_mandatory, # noqa: E501
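For context, the new code path can be exercised roughly as follows. This is a minimal sketch based on the arguments visible in this diff and in the test file below; the model identifier and output path are hypothetical stand-ins, not part of the PR.

```python
# Sketch only: export an in-memory model together with its tokenizer,
# leaving source_path as None. Model id and target path are hypothetical.
from sparseml import export
from sparseml.transformers import SparseAutoModelForCausalLM, SparseAutoTokenizer

stub = "some-org/some-causal-lm"  # hypothetical model identifier
model = SparseAutoModelForCausalLM.from_pretrained(stub)
tokenizer = SparseAutoTokenizer.from_pretrained(stub)

export(
    model=model,                    # source_path stays None for an in-memory model
    tokenizer=tokenizer,            # copied into the deployment folder via save_pretrained
    target_path="./export-output",  # hypothetical
    integration="transformers",
    task="text-generation",         # hypothetical task name
    sequence_length=384,
)
```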
12 changes: 11 additions & 1 deletion src/sparseml/export/helpers.py
@@ -115,6 +115,8 @@ def create_deployment_folder(
target_path: Union[Path, str],
deployment_directory_files_mandatory: List[str],
source_path: Union[Path, str, None] = None,
source_config: Optional["PreTrainedConfig"] = None, # noqa F401
source_tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa F401
deployment_directory_files_optional: Optional[List[str]] = None,
deployment_directory_name: str = "deployment",
onnx_model_name: Optional[str] = None,
@@ -135,6 +137,8 @@ def create_deployment_folder(
The files will be copied to target_path/deployment_directory_name.
:param source_path: The path to the source folder (where the original model
files are stored)
:param source_config: Optional Hugging Face config to copy to deployment dir
:param source_tokenizer: Optional Hugging Face tokenizer to copy to deployment dir
:param deployment_directory_files_mandatory: The mandatory list of files
to copy to the deployment directory. If the file is an ONNX model
(or ONNX data file), the file will be copied from target_path.
@@ -161,10 +165,16 @@
deployment_folder_dir=deployment_folder_dir,
onnx_model_name=onnx_model_name,
)

if source_path is None:
# exporting an instantiated model
if source_config is not None:
source_config.save_pretrained(deployment_folder_dir)
if source_tokenizer is not None:
source_tokenizer.save_pretrained(deployment_folder_dir)
return deployment_folder_dir

# copy the relevant files from source_path
# exporting from a source path, copy the relevant files to deployment directory
for file_name in deployment_directory_files_mandatory:
copy_mandatory_deployment_files(
file_name, source_path, target_path, onnx_model_name, deployment_folder_dir
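To make the new `source_path is None` branch concrete: when an instantiated model is exported, the deployment directory is populated directly from the in-memory config and tokenizer. A small sketch of what those `save_pretrained` calls write, using standard Hugging Face behaviour; the model id and path below are stand-ins.

```python
# Sketch only: mirrors the source_path-is-None branch of create_deployment_folder.
from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("gpt2")        # stand-in for source_config
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for source_tokenizer

deployment_dir = "./export-output/deployment"      # hypothetical target_path / deployment_directory_name
config.save_pretrained(deployment_dir)             # writes config.json
tokenizer.save_pretrained(deployment_dir)          # writes tokenizer.json / vocab files
```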
24 changes: 23 additions & 1 deletion src/sparseml/integration_helper_functions.py
@@ -37,23 +37,29 @@ class Integrations(Enum):

def resolve_integration(
source_path: Union[Path, str, None] = None,
source_model: Optional["PreTrainedModel"] = None, # noqa F401
integration: Optional[str] = None,
) -> str:
"""
Resolve the integration to use.

If integration is not provided, attempt to infer it from the source_path.
If integration is not provided, attempt to infer it from the source_path or model.
Once the integration is resolved, perform the hot import to register
the integration helper functions.

:param source_path: The path to the PyTorch model to export.
:param source_model: An instantiated model to export
:param integration: Optional name of the integration to use. If not provided,
will attempt to infer it from the source_path.
:return: The name of the integration to use for exporting the model.
"""

integration = integration or _infer_integration_from_source_path(source_path)

# attempt to infer transformers based on model attribute
if source_model is not None and hasattr(source_model, "config_class"):
Contributor: Why is config_class the deciding attribute here?

Author: This is really just to tell if it's a PreTrainedModel, but I didn't want to have to add the transformers dependency to this part of the repo.

integration = Integrations.transformers.value

if integration == Integrations.image_classification.value:
import sparseml.pytorch.image_classification.integration_helper_functions # noqa F401

@@ -72,6 +78,22 @@ def resolve_integration(
)


def remove_past_key_value_support_from_config(config):
"""
Modify config of the causal language model so that it turns off the
past key value support. This means that the model initialized from
this config will not take past key values as input and will not output
past key values.
"""
# the model will not take past_key_values as input
config.is_decoder = True
# whether to use past key values as input
config.use_past = False
# whether to output past key values
config.use_cache = False
return config


def _infer_integration_from_source_path(
source_path: Union[Path, str, None] = None
) -> Optional[str]:
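The `config_class` check discussed in the review thread above is effectively duck typing for Hugging Face models, and `remove_past_key_value_support_from_config` simply flips three config flags. A sketch of both helpers in action; transformers is imported here only for the illustration, and the model id is a stand-in.

```python
# Sketch only, assuming sparseml and transformers are installed.
from transformers import AutoModelForCausalLM

from sparseml.integration_helper_functions import (
    remove_past_key_value_support_from_config,
    resolve_integration,
)

model = AutoModelForCausalLM.from_pretrained("gpt2")  # stand-in model

# Every PreTrainedModel subclass defines `config_class`, so the helper can
# detect a transformers model without importing transformers itself.
assert hasattr(model, "config_class")
print(resolve_integration(source_model=model))  # -> "transformers"

# The model initialized from this config neither accepts nor returns
# past_key_values (KV-cache inputs can still be injected later as a graph optimization).
config = remove_past_key_value_support_from_config(model.config)
assert config.is_decoder is True and config.use_cache is False
```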
3 changes: 2 additions & 1 deletion src/sparseml/pytorch/torch_to_onnx_exporter.py
@@ -129,7 +129,8 @@ def remove_leftover_files(self):
torch_onnx_export_transform, _TorchOnnxExport
), "Expected the first transform from self.transform to be _TorchOnnxExport"
for file in torch_onnx_export_transform.leftover_files:
os.remove(file)
if os.path.exists(file):
Contributor: Just curious, why is this change needed now?

Author: Honestly, I am not quite sure. I would occasionally get an error about deleting a file that didn't exist when exporting, and this fixed it.

os.remove(file)


class _TorchOnnxExport(BaseTransform):
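As a side note on the fix above: the same "delete only if present" intent can also be expressed without a separate existence check, which avoids the small window where the file could disappear between the check and the removal. This is only an alternative sketch, not what the PR does.

```python
import contextlib
import os

# Remove a leftover file, silently tolerating the case where it is already gone.
with contextlib.suppress(FileNotFoundError):
    os.remove("leftover_model.onnx")  # hypothetical leftover file path
```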
4 changes: 2 additions & 2 deletions src/sparseml/transformers/integration_helper_functions.py
@@ -24,6 +24,7 @@
from sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
Integrations,
remove_past_key_value_support_from_config,
)
from sparseml.transformers.finetune.data.data_helpers import format_calibration_data
from sparseml.transformers.utils.helpers import (
@@ -34,7 +35,6 @@
OPTIONAL_DEPLOYMENT_FILES,
TaskNames,
create_fake_dataloader,
remove_past_key_value_support_from_config,
resolve_sequence_length,
)
from sparseml.transformers.utils.initializers import (
@@ -115,7 +115,7 @@ def create_data_loader(
data_args: Optional[Dict[str, Any]] = None,
config: Optional["AutoConfig"] = None, # noqa F821
source_path: Optional[str] = None,
sequence_length: Optional[int] = None,
sequence_length: int = 384,
tokenizer: Optional["AutoTokenizer"] = None, # noqa F821
dataset_with_labels: bool = False,
**kwargs,
16 changes: 0 additions & 16 deletions src/sparseml/transformers/utils/helpers.py
@@ -94,22 +94,6 @@ class TaskNames(Enum):
}


def remove_past_key_value_support_from_config(config: AutoConfig) -> AutoConfig:
"""
Modify config of the causal language model so that it turns off the
past key value support. This means that the model initialized from
this config will not take past key values as input and will not output
past key values.
"""
# not take past_key_values as input
config.is_decoder = True
# whether to use past key values an input
config.use_past = False
# whether to output past key values
config.use_cache = False
return config


def is_transformer_model(source_path: Union[Path, str]) -> bool:
"""
:param source_path: The path to the model
23 changes: 5 additions & 18 deletions tests/sparseml/export/transformers/test_generative_transformers.py
@@ -23,10 +23,7 @@

from huggingface_hub import snapshot_download
from sparseml import export
from sparseml.transformers import SparseAutoConfig, SparseAutoModelForCausalLM
from sparseml.transformers.utils.helpers import (
remove_past_key_value_support_from_config,
)
from sparseml.transformers import SparseAutoModelForCausalLM, SparseAutoTokenizer


@pytest.mark.parametrize(
@@ -49,21 +46,11 @@ def test_export_initialized_model_no_source_path(self, setup):
# export the transformer model, that is being passed to the
# `export` API directly as an object
source_path, target_path, task = setup
config = remove_past_key_value_support_from_config(
SparseAutoConfig.from_pretrained(source_path)
)
export(
model=SparseAutoModelForCausalLM.from_pretrained(
source_path, config=config
),
model=SparseAutoModelForCausalLM.from_pretrained(source_path),
tokenizer=SparseAutoTokenizer.from_pretrained(source_path),
target_path=target_path,
integration="transformers",
sequence_length=384,
# we need to disable applying kv cache injection
# because the script does not have access to the
# config.json (we are not creating a full deployment
# directory during the export)
graph_optimizations="none",
task=task,
validate_correctness=True,
num_export_samples=2,
@@ -73,11 +60,11 @@ def test_export_initialized_model_no_source_path(self, setup):
)
assert (target_path / "deployment" / "model.onnx").exists()
assert not (target_path / "deployment" / "model.data").exists()
# assert that kv cache injection has not been applied
# assert that kv cache injection has been applied
onnx_model = onnx.load(
str(target_path / "deployment" / "model.onnx"), load_external_data=False
)
assert not any(
assert any(
inp.name == "past_key_values.0.key" for inp in onnx_model.graph.input
)
