diff --git a/src/sparseml/pytorch/model_load/helpers.py b/src/sparseml/pytorch/model_load/helpers.py
index b4e94700866..9016583ddf3 100644
--- a/src/sparseml/pytorch/model_load/helpers.py
+++ b/src/sparseml/pytorch/model_load/helpers.py
@@ -233,6 +233,7 @@ def save_model_and_recipe(
     save_path: str,
     tokenizer: Optional[Any] = None,
     save_safetensors: bool = False,
+    save_compressed: bool = False,
 ):
     """
     Save a model, tokenizer and the currently loaded recipe to file
@@ -241,9 +242,12 @@
     :param save_path: path to save output to
     :param tokenizer: model tokenizer to save
     :param save_safetensors: whether to save as safetensors or pickle (bin)
+    :param save_compressed: whether to compress sparse weights on disk
     """
 
-    model.save_pretrained(save_path, safe_serialization=save_safetensors)
+    model.save_pretrained(
+        save_path, save_compressed=save_compressed, safe_serialization=save_safetensors
+    )
 
     if tokenizer is not None:
         tokenizer.save_pretrained(save_path)
diff --git a/src/sparseml/transformers/compression/README.md b/src/sparseml/transformers/compression/README.md
index 771e77d0891..0b4a20579f6 100644
--- a/src/sparseml/transformers/compression/README.md
+++ b/src/sparseml/transformers/compression/README.md
@@ -61,17 +61,16 @@ model = SparseAutoModelForCausalLM.from_pretrained(
 ```
 
 Saving a compressed model with an explicitly provided compression config. The config
-is saved to the model's `config.json` file
+is saved to the model's `config.json` file. **Note:** the model must have been
+initialized with `SparseAutoModelForCausalLM.from_pretrained()`
 
 ```python
-from sparseml.transformers.utils import SparseAutoModelForCausalLM
 from sparseml.transformers.compression import BitmaskConfig
 
 output_dir = "/PATH/TO/SAVE/COMPRESSED_MODEL"
 sparsity_config = BitmaskConfig()
 
-SparseAutoModelForCausalLM.save_pretrained(
-    model,
+model.save_pretrained(
     save_directory=output_dir,
     sparsity_config=sparsity_config,
 )
@@ -80,14 +79,13 @@ SparseAutoModelForCausalLM.save_pretrained(
 
 Saving a compressed model, inferring the config from the model attributes
 
 ```python
-SparseAutoModelForCausalLM.save_compressed(
-    model,
+model.save_compressed(
     save_directory=output_dir,
 )
 
 # alternative
-SparseAutoModelForCausalLM.save_pretrained(
-    model,
+model.save_pretrained(
     save_directory=output_dir,
     save_compressed=True
 )
@@ -97,19 +95,38 @@ Saving a model in the dense format, but still include a sparsity config in `conf
 with global sparsity and sparsity structure information
 
 ```python
-from sparseml.transformers.utils import SparseAutoModelForCausalLM
 from sparseml.transformers.compression import DenseSparsityConfig
 
-SparseAutoModelForCausalLM.save_pretrained(
-    model,
+model.save_pretrained(
     save_directory=output_dir,
     sparsity_config=DenseSparsityConfig()
 )
 ```
 
+## Enable Compression During One-Shot and Sparse Finetuning
+Models that are saved in a supported compressed format on disk will automatically be
+decompressed when loaded as input to `sparseml.transformers.oneshot` or
+`sparseml.transformers.train`.
+
+To enable compression on save after one-shot or finetuning, simply add the
+`save_compressed=True` argument to `sparseml.transformers.oneshot` or
+`sparseml.transformers.train`.
+
+```python
+from sparseml.transformers import train
+
+train(
+    save_compressed=True,
+    model="neuralmagic/TinyLlama-1.1B-Chat-v1.0-pruned2.4",
+    recipe=RECIPE,
+    dataset=DATASET
+)
+```
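+
+The same flag applies to one-shot runs. Below is a minimal sketch mirroring the
+training example above, assuming `oneshot` accepts the same saving arguments as
+`train` (`RECIPE` and `DATASET` are placeholders):
+
+```python
+from sparseml.transformers import oneshot
+
+oneshot(
+    save_compressed=True,
+    model="neuralmagic/TinyLlama-1.1B-Chat-v1.0-pruned2.4",
+    recipe=RECIPE,
+    dataset=DATASET
+)
+```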
+
 
 ## Example Code
 
-Loads a 50% sparse model, compresses it using the inferred bitmask compression, then
+Loads a 60% sparse model, compresses it using the inferred bitmask compression, then
 reloads the compressed model.
 
 ```python
@@ -126,11 +143,10 @@ with measure_cuda_memory() as m:
     model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="cuda:0")
 print(f"Load dense model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")
 
-print(f"Sparsity config before compression: {model.sparsity_config}")
+sparsity_config = getattr(model, "sparsity_config", None)
+print(f"Sparsity config before compression: {sparsity_config}")
 with measure_cuda_memory() as m:
-    SparseAutoModelForCausalLM.save_compressed(
-        model, OUTPUT_PATH
-    )
+    model.save_compressed(OUTPUT_PATH)
 print(f"Save compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")
 
 torch.cuda.set_device(1)
@@ -139,5 +155,6 @@ with measure_cuda_memory() as m:
     OUTPUT_PATH, device_map="cuda:1"
 )
 print(f"Load compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")
-print(f"Sparsity config after compression: {model_again.sparsity_config}")
+sparsity_config = getattr(model_again, "sparsity_config", None)
+print(f"Sparsity config after compression: {sparsity_config}")
 ```
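+
+To verify what was written to disk, the saved config can be read back. This mirrors
+the checks in this repository's tests (`SPARSITY_CONFIG_NAME` is defined in
+`sparseml.transformers.utils.helpers`):
+
+```python
+from transformers import AutoConfig
+
+from sparseml.transformers.utils.helpers import SPARSITY_CONFIG_NAME
+
+saved_config = AutoConfig.from_pretrained(OUTPUT_PATH)
+# None when the model was saved dense without a sparsity config
+print(getattr(saved_config, SPARSITY_CONFIG_NAME, None))
+```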
diff --git a/src/sparseml/transformers/compression/utils/__init__.py b/src/sparseml/transformers/compression/utils/__init__.py
index 2543d7839bd..560435126ad 100644
--- a/src/sparseml/transformers/compression/utils/__init__.py
+++ b/src/sparseml/transformers/compression/utils/__init__.py
@@ -14,5 +14,6 @@
 # flake8: noqa
 
+from .compress_save import *
 from .helpers import *
 from .safetensors_load import *
diff --git a/src/sparseml/transformers/compression/utils/compress_save.py b/src/sparseml/transformers/compression/utils/compress_save.py
new file mode 100644
index 00000000000..18c00d360bb
--- /dev/null
+++ b/src/sparseml/transformers/compression/utils/compress_save.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import types
+import weakref
+from functools import wraps
+from typing import Optional
+
+from transformers import PreTrainedModel
+from transformers.file_utils import CONFIG_NAME
+
+from sparseml.transformers.compression.compressors import ModelCompressor
+from sparseml.transformers.compression.config import CompressionConfig
+from sparseml.transformers.utils.helpers import SPARSITY_CONFIG_NAME
+
+
+__all__ = [
+    "modify_save_pretrained",
+    "add_save_compressed_method",
+]
+
+
+def modify_save_pretrained(model: PreTrainedModel):
+    """
+    Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that
+    supports compression
+    """
+
+    def save_pretrained_compressed(save_pretrained_method):
+        if getattr(save_pretrained_method, "_overridden", False):
+            # `model.save_pretrained` has already been replaced, return.
+            return save_pretrained_method
+
+        # Keep a weak reference to the model class and unbound save_pretrained
+        # method so we can call the original
+        model_ref = weakref.ref(save_pretrained_method.__self__)
+        original_save_pretrained = save_pretrained_method.__func__
+        model_class = model_ref().__class__
+        del save_pretrained_method
+
+        @wraps(original_save_pretrained)
+        def save_pretrained_wrapper(
+            save_directory: str,
+            sparsity_config: Optional[CompressionConfig] = None,
+            save_compressed: bool = False,
+            **kwargs,
+        ):
+            """
+            Wrapper around PreTrainedModel.save_pretrained(), adds functionality for
+            saving models in a compressed format on disk. The compression format is
+            saved to the model's config file
+
+            :param save_directory: output directory to save model to
+            :param sparsity_config: optional sparsity config to compress model with,
+                if no config is provided it will be inferred from the model
+            :param save_compressed: whether or not to compress the model on disk
+            :param kwargs: additional kwargs to pass on to model.save_pretrained
+            """
+            model = model_ref()
+
+            if sparsity_config is not None:
+                # if a sparsity config is provided, always save compressed
+                sparsity_config.fill_config_details(model)
+                save_compressed = True
+            elif save_compressed:
+                # try to infer a sparsity config from the model if none is provided
+                sparsity_config = CompressionConfig.infer_config_from_model(
+                    model, compress=save_compressed
+                )
+
+            if sparsity_config is None:
+                # model is not sparse, save as dense
+                return original_save_pretrained.__get__(model, model_class)(
+                    save_directory, **kwargs
+                )
+
+            # if we've gotten to this point we have a config so we can run compression
+            kwargs["safe_serialization"] = True
+            compressor = ModelCompressor.load_from_registry(
+                sparsity_config.format, config=sparsity_config
+            )
+
+            # state_dict gets passed in as a kwarg for FSDP models
+            state_dict = kwargs.get("state_dict", None)
+            if state_dict is None:
+                state_dict = model.state_dict()
+
+            # only compress a populated state dict; an empty state dict indicates
+            # a non-main process, which should not run compression
+            if state_dict is not None and len(state_dict) > 0:
+                compressed_state_dict = compressor.compress(state_dict)
+                kwargs["state_dict"] = compressed_state_dict
+
+            original_save_pretrained.__get__(model, model_class)(
+                save_directory, **kwargs
+            )
+            sparsity_config_data = sparsity_config.dict()
+            config_file_path = os.path.join(save_directory, CONFIG_NAME)
+
+            # add the sparsity config to the model's config file
+            with open(config_file_path, "r") as config_file:
+                config_data = json.load(config_file)
+            config_data[SPARSITY_CONFIG_NAME] = sparsity_config_data
+            with open(config_file_path, "w") as config_file:
+                json.dump(config_data, config_file, indent=4, sort_keys=True)
+
+        save_pretrained_wrapper._overridden = True
+        return save_pretrained_wrapper
+
+    # wrap save_pretrained
+    model.save_pretrained = save_pretrained_compressed(model.save_pretrained)
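+
+    # Example: after modify_save_pretrained(model), a call like
+    # model.save_pretrained(out_dir, save_compressed=True) infers a sparsity
+    # config, writes a compressed safetensors state dict, and records the config
+    # in config.json under SPARSITY_CONFIG_NAME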
+
+
+def add_save_compressed_method(model: PreTrainedModel):
+    """
+    Overrides an instance of PreTrainedModel to add a save_compressed method that
+    wraps PreTrainedModel.save_pretrained(). Requires that modify_save_pretrained()
+    has already been run on the model instance
+    """
+
+    def save_compressed(
+        self,
+        save_directory: str,
+        sparsity_config: Optional[CompressionConfig] = None,
+        **kwargs,
+    ):
+        """
+        Alias for PreTrainedModel.save_pretrained() that always saves in a
+        compressed format
+
+        :param save_directory: output directory to save model to
+        :param sparsity_config: optional sparsity config to compress model with, if no
+            config is provided it will be inferred from the model
+        :param kwargs: additional kwargs to pass on to model.save_pretrained
+        """
+        return self.save_pretrained(
+            save_directory=save_directory,
+            sparsity_config=sparsity_config,
+            save_compressed=True,
+            **kwargs,
+        )
+
+    model.save_compressed = types.MethodType(save_compressed, model)
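+
+    # Example: model.save_compressed(out_dir) is equivalent to
+    # model.save_pretrained(out_dir, save_compressed=True)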
diff --git a/src/sparseml/transformers/finetune/runner.py b/src/sparseml/transformers/finetune/runner.py
index 968b25c9726..e970e3b7264 100644
--- a/src/sparseml/transformers/finetune/runner.py
+++ b/src/sparseml/transformers/finetune/runner.py
@@ -194,6 +194,7 @@ def one_shot(self, stage: Optional[str] = None):
             save_path=self._output_dir,
             tokenizer=self.tokenizer,
             save_safetensors=self._training_args.save_safetensors,
+            save_compressed=self._training_args.save_compressed,
         )
 
     def train(self, checkpoint: str, stage: Optional[str] = None):
diff --git a/src/sparseml/transformers/finetune/session_mixin.py b/src/sparseml/transformers/finetune/session_mixin.py
index 96464317b5c..72d18d98a9b 100644
--- a/src/sparseml/transformers/finetune/session_mixin.py
+++ b/src/sparseml/transformers/finetune/session_mixin.py
@@ -457,25 +457,29 @@ def save_model(
 
         :param output_dir: the path to save the recipes into
        """
-        self._check_super_defined("save_model")
-        super().save_model(output_dir=output_dir, _internal_call=_internal_call)
-
         if session_manager.active_session() is None:
             return  # nothing to save
 
         if output_dir is None:
             output_dir = self.args.output_dir
 
-        # don't export the gathered model on checkpoints
-        if is_fsdp_model(self.model) and not _internal_call:
+        if not is_fsdp_model(self.model):
+            self.model.save_pretrained(
+                output_dir,
+                save_compressed=self.args.save_compressed,
+                safe_serialization=self.args.save_safetensors,
+            )
+        else:  # FSDP model
             save_pretrained_fsdp(
                 model=self.model,
                 accelerator=self.accelerator,
                 output_dir=output_dir,
+                save_compressed=self.args.save_compressed,
                 save_safetensors=self.metadata.get("save_safetensors", False),
             )
 
         self.save_state()
+        self.tokenizer.save_pretrained(output_dir)
         if not _is_oneshot:  # optimizer/scheduler not relevant to one-shot
             self.save_optimizer_and_scheduler(output_dir)
diff --git a/src/sparseml/transformers/finetune/text_generation.py b/src/sparseml/transformers/finetune/text_generation.py
index ca0fc5c75d4..76673a3a5ea 100644
--- a/src/sparseml/transformers/finetune/text_generation.py
+++ b/src/sparseml/transformers/finetune/text_generation.py
@@ -309,10 +309,6 @@ def main(
     if isinstance(tokenizer, str) or tokenizer is None:
         tokenizer = initialize_tokenizer_from_path(model_args, model, teacher)
 
-    # setup new SparseSession unless user requests otherwise
-    if training_args.clear_sparse_session:
-        session_manager.create_session()
-        session_manager.active_session().reset()
     session_manager.pre_initialize_structure(model=model, framework=Framework.pytorch)
 
-    # intialize session manager
+    # initialize session manager
@@ -381,6 +377,10 @@ def main(
     if training_args.do_predict:
         stage_runner.predict()
 
+    # Clean up the SparseSession before exit if requested
+    if training_args.clear_sparse_session:
+        session_manager.active_session().reset()
+
 
 if __name__ == "__main__":
     apply()
diff --git a/src/sparseml/transformers/finetune/training_args.py b/src/sparseml/transformers/finetune/training_args.py
index 41d154ab28b..083fb5c5e2b 100644
--- a/src/sparseml/transformers/finetune/training_args.py
+++ b/src/sparseml/transformers/finetune/training_args.py
@@ -54,6 +54,10 @@ class TrainingArguments(HFTrainingArgs):
             )
         },
     )
+    save_compressed: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to compress sparse models during save"},
+    )
     do_oneshot: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to run one-shot calibration"},
diff --git a/src/sparseml/transformers/utils/sparse_model.py b/src/sparseml/transformers/utils/sparse_model.py
index a580424b2e4..58f9b9a5401 100644
--- a/src/sparseml/transformers/utils/sparse_model.py
+++ b/src/sparseml/transformers/utils/sparse_model.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import inspect
-import json
 import logging
 import os
 from pathlib import Path
@@ -29,15 +28,18 @@
     AutoModelForTokenClassification,
     PreTrainedModel,
 )
-from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
+from transformers.file_utils import WEIGHTS_NAME
 
 from sparseml.pytorch.model_load.helpers import (
     apply_recipe_structure_to_model,
     log_model_load,
 )
-from sparseml.transformers.compression import CompressionConfig, ModelCompressor
-from sparseml.transformers.compression.utils import infer_compressor_from_model_config
-from sparseml.transformers.utils.helpers import SPARSITY_CONFIG_NAME, resolve_recipe
+from sparseml.transformers.compression.utils import (
+    add_save_compressed_method,
+    infer_compressor_from_model_config,
+    modify_save_pretrained,
+)
+from sparseml.transformers.utils.helpers import resolve_recipe
 from sparseml.utils import download_zoo_training_dir
 from sparseml.utils.fsdp.context import main_process_first_context
 
@@ -111,6 +113,10 @@ def skip(*args, **kwargs):
         )
         logger.setLevel(level=restore_log_level)
 
+        # override the PreTrainedModel instance with compression save functions
+        modify_save_pretrained(model)
+        add_save_compressed_method(model)
+
         # If model is compressed on disk, decompress and load the weights
         if compressor is not None:
             compressor.overwrite_weights(
@@ -126,85 +132,6 @@
             )
             return model
 
-    @staticmethod
-    def save_pretrained(
-        model: PreTrainedModel,
-        save_directory: str,
-        sparsity_config: Optional[CompressionConfig] = None,
-        save_compressed: bool = False,
-        **kwargs,
-    ):
-        """
-        Wrapper around PreTrainedModel.save_pretrained(), adds functionality for
-        saving models in a compressed format on disk. The compression format is
-        saved to the model's config file
-
-        :param model: transformers model to save
-        :param save_directory: output directory to save model to
-        :param sparsity_config: optional sparsity config to compress model with, if no
-            config is provided it will be inferred from the model
-        :param save_compresed: whether or not to compress the model on disk
-        :param kwargs: additional kwargs to pass on to model.save_pretrained
-        """
-        if sparsity_config is not None:
-            # if a sparsity config is provided, always save compressed
-            sparsity_config.fill_config_details(model)
-            save_compressed = True
-        elif save_compressed:
-            # try to infer a sparsity config from the model if none is provided
-            sparsity_config = CompressionConfig.infer_config_from_model(
-                model, compress=save_compressed
-            )
-
-        if sparsity_config is None:
-            # model is not sparse, save as dense
-            return model.save_pretrained(save_directory, **kwargs)
-
-        # if we've gotten to this point we can run compression since we have a config
-        kwargs["safe_serialization"] = True
-        compressor = ModelCompressor.load_from_registry(
-            sparsity_config.format, config=sparsity_config
-        )
-
-        compressed_state_dict = compressor.compress(model.state_dict())
-        kwargs["state_dict"] = compressed_state_dict
-
-        model.save_pretrained(save_directory, **kwargs)
-        sparsity_config_data = sparsity_config.dict()
-        config_file_path = os.path.join(save_directory, CONFIG_NAME)
-
-        # add the sparsity config to the model's config file
-        with open(config_file_path, "r") as config_file:
-            config_data = json.load(config_file)
-        config_data[SPARSITY_CONFIG_NAME] = sparsity_config_data
-        with open(config_file_path, "w") as config_file:
-            json.dump(config_data, config_file, indent=4, sort_keys=True)
-
-    @staticmethod
-    def save_compressed(
-        model: PreTrainedModel,
-        save_directory: str,
-        sparsity_config: Optional[CompressionConfig] = None,
-        **kwargs,
-    ):
-        """
-        Alias for SparseAutoModelForCausalLM.save_pretrained() that always saves in a
-        compressed format
-
-        :param model: transformers model to save
-        :param save_directory: output directory to save model to
-        :param sparsity_config: optional sparsity config to compress model with, if no
-            config is provided it will be inferred from the model
-        :param kwargs: additional kwargs to pass on to model.save_pretrained
-        """
-        return SparseAutoModelForCausalLM.save_pretrained(
-            model=model,
-            save_directory=save_directory,
-            sparsity_config=sparsity_config,
-            save_compressed=True,
-            **kwargs,
-        )
-
 
 class SparseAutoModel:
     """
diff --git a/src/sparseml/utils/fsdp/helpers.py b/src/sparseml/utils/fsdp/helpers.py
index abae74f612c..d2def7fef39 100644
--- a/src/sparseml/utils/fsdp/helpers.py
+++ b/src/sparseml/utils/fsdp/helpers.py
@@ -140,7 +140,13 @@ def find_and_move_state_dicts_to_cpu(output_dir: str):
             _LOGGER.info(f"Moved state dict {model_file} to cpu")
 
 
-def save_pretrained_fsdp(model, accelerator, output_dir, save_safetensors: bool = True):
+def save_pretrained_fsdp(
+    model,
+    accelerator,
+    output_dir,
+    save_safetensors: bool = True,
+    save_compressed: bool = False,
+):
-    full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
     """
     Gathers the full FSDP state dict of the model onto rank0 GPU, then uses it to save
@@ -150,6 +156,7 @@
     :param accelerator: Accelerator instance used to perform unwrapping
     :param output_dir: where to save output model
-    :param save_safetensors: True to safe in safetensors format, otherwise .bin
+    :param save_safetensors: True to save in safetensors format, otherwise .bin
+    :param save_compressed: whether to compress sparse weights on disk
     """
+    full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
     with FullyShardedDataParallel.state_dict_type(
         model, StateDictType.FULL_STATE_DICT, full_state_dict_config
@@ -161,6 +168,7 @@
             is_main_process=accelerator.is_main_process,
             save_function=accelerator.save,
             state_dict=state_dict,
+            save_compressed=save_compressed,
             safe_serialization=save_safetensors,
         )
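+        # note: with rank0_only=True, non-rank0 processes receive an empty state
+        # dict; the compression-aware save_pretrained wrapper skips compression
+        # for those ranks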
diff --git a/tests/sparseml/transformers/compression/test_sparse_auto.py b/tests/sparseml/transformers/compression/test_sparse_auto.py
index b88881e0dff..bd621eff9c9 100644
--- a/tests/sparseml/transformers/compression/test_sparse_auto.py
+++ b/tests/sparseml/transformers/compression/test_sparse_auto.py
@@ -30,15 +30,16 @@
 
 
 @pytest.mark.parametrize(
-    "compressed,config,dtype",
+    "compressed,config,dtype,use_compress",
     [
-        [True, None, torch.float32],
-        [False, DenseSparsityConfig(), torch.float16],
-        [True, BitmaskConfig(), torch.bfloat16],
-        [False, BitmaskConfig(), torch.float32],
+        [True, None, torch.float32, False],
+        [False, DenseSparsityConfig(), torch.float16, False],
+        [True, BitmaskConfig(), torch.bfloat16, False],
+        [True, BitmaskConfig(), torch.bfloat16, True],
+        [False, BitmaskConfig(), torch.float32, False],
     ],
 )
-def test_sparse_model_reload(compressed, config, dtype, tmp_path):
+def test_sparse_model_reload(compressed, config, dtype, use_compress, tmp_path):
     recipe_str = "tests/sparseml/transformers/obcq/test_tiny2.yaml"
     model_path = "Xenova/llama2.c-stories15M"
     device = "cuda:0"
@@ -72,12 +73,15 @@ def test_sparse_model_reload(compressed, config, dtype, tmp_path):
     inferred_structure = CompressionConfig.infer_sparsity_structure()
     assert inferred_structure == "0:0"
 
-    SparseAutoModelForCausalLM.save_pretrained(
-        model,
-        tmp_path / "compress_out",
-        sparsity_config=config,
-        save_compressed=compressed,
-    )
+    if use_compress:
+        model.save_compressed(tmp_path / "compress_out", sparsity_config=config)
+    else:
+        model.save_pretrained(
+            tmp_path / "compress_out",
+            sparsity_config=config,
+            save_compressed=compressed,
+        )
+
     config = AutoConfig.from_pretrained(tmp_path / "compress_out")
     sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
     assert (
@@ -115,7 +119,7 @@ def test_dense_model_save(tmp_path):
     inferred_structure = CompressionConfig.infer_sparsity_structure()
     assert inferred_structure == "unstructured"
 
-    SparseAutoModelForCausalLM.save_pretrained(model, tmp_path / "dense_out")
+    model.save_pretrained(tmp_path / "dense_out")
 
     config = AutoConfig.from_pretrained(tmp_path / "dense_out")
     sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
     assert sparsity_config is None