From ff4d5dfe46e9a2ddfa83a124183ae1d30f6da2af Mon Sep 17 00:00:00 2001
From: Oleg S <97077423+RobotSail@users.noreply.github.com>
Date: Mon, 14 Oct 2024 12:06:00 -0400
Subject: [PATCH] wip - saving logic

---
 .../training/internal/accelerator.py | 71 +++++++++++++++++--
 src/instructlab/training/utils.py    | 25 ++++---
 2 files changed, 80 insertions(+), 16 deletions(-)

diff --git a/src/instructlab/training/internal/accelerator.py b/src/instructlab/training/internal/accelerator.py
index 949c0c17..211a4e04 100644
--- a/src/instructlab/training/internal/accelerator.py
+++ b/src/instructlab/training/internal/accelerator.py
@@ -7,7 +7,9 @@
 from peft import LoraModel
 from torch import distributed as dist
 from torch import nn
+from torch.cuda import empty_cache
 from transformers import PreTrainedModel
+import torch

 def wraps(module: nn.Module, wrapped_classes: Tuple[Any]) -> bool:
     """Checks if a module or its children are an instance of one of the provided classes.
@@ -51,20 +53,22 @@ def prepare(self, *args, **kwargs):
         # Make sure this only happens once per object lifecycle - we call accelerator.prepare
         # several times.
         num_times_found = 0
+        using_lora = torch.ByteTensor([False]).cuda()
         if self.distributed_type == DistributedType.FSDP and self.is_main_process and not self._cpu_model:
             for arg in args:
                 if isinstance(arg, nn.Module) and wraps(arg, PreTrainedModel) and wraps(arg, LoraModel):
-                    self._is_lora = True
+                    using_lora[0] = True
                     num_times_found += 1
-                    # from IPython import embed; embed()
                     # cpu model setter logic will handle validation - but this may be a stupid idea and
                     # we should instead handle it here
-                    # self.cpu_model = arg
-
-
+                    self.cpu_model = arg
+                    break
+
         print(f'number of times we found a lora pretrained arg: {num_times_found}')
-        dist.breakpoint()
         dist.barrier()
+        dist.all_reduce(using_lora, op=dist.ReduceOp.MAX)
+        if using_lora[0]:
+            self._is_lora = True
         return super().prepare(*args, **kwargs)

     @property
@@ -91,4 +95,57 @@ def cpu_model(self, model: nn.Module) -> nn.Module | None:
             # As long as we correctly prepare the model through Accelerate, we should not hit this.
             raise RuntimeError('Copying a model from the GPU to the CPU is not supported.')

-        self._cpu_model = deepcopy(model)
\ No newline at end of file
+        self._cpu_model = deepcopy(model)
+
+    def save_lora_fsdp(self, model: nn.Module, *args, **kwargs) -> None:
+        """Extension of the `accelerate.Accelerator.save_model` method.
+
+        This provides the ability to save a model in SafeTensors format when training a LoRA with FSDP.
+
+        Args:
+            model (nn.Module): The accelerator-wrapped model to save.
+        """
+
+
+        if self.distributed_type != DistributedType.FSDP:
+            raise RuntimeError('`__SuperAccelerator.save_fsdp_lora` was called when FSDP was not being used.')
+        if not self._is_lora:
+            raise RuntimeError('`__SuperAccelerator.save_fsdp_lora` was called but was not configured to use LoRA')
+
+        print('GETTING OLD MODEL STATE DICT')
+        model_state = self.get_state_dict(model, unwrap=True)
+        print('COPYING CPU MODEL')
+        if self.is_main_process:
+            tmp_model: LoraModel = deepcopy(self.cpu_model)
+            print('LOADING STATE DICT INTO TEMP MODEL')
+            tmp_model.load_state_dict(model_state)
+            print('MERGING & UNLOADING TEMP MODEL')
+            tmp_model.merge_and_unload(True)
+            print('GETTING STATE DICT FROM TEMP MODEL')
+            model_state = tmp_model.state_dict()
+
+            old_get_state_dict = self.get_state_dict
+            def _custom_get_state_dict(ref: Accelerator, *args, **kwargs):
+                """
+                Custom function to trick `accelerate.Accelerator` to get a state dict that will work
+                when training with LoRA & FSDP.
+                """
+                print('RETURNING TEMP MODEL STATE DICT INTO ACCELERATORS SAVE MODEL')
+                return model_state
+
+            print('OVERWRITING get_state_dict WITH CUSTOM FN')
+            self.get_state_dict = _custom_get_state_dict
+            print('CALLING ACCELERATOR SAVE_MODEL')
+            self.save_model(model, *args, **kwargs)
+            print('RETURNING get_state_dict FUNCTION TO OLD')
+            self.get_state_dict = old_get_state_dict
+
+            print('DELETING TMP MODEL')
+            del tmp_model
+            empty_cache()
+
+
+
+        print('SAVED SUCCESSFULLY')
+        # from IPython import embed; embed()
+
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 08139366..17e32261 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -662,16 +662,23 @@ def save_hf_format_accelerate_lora(
     model: PreTrainedModel,
     tokenizer: PreTrainedTokenizer,
     accelerator: __SuperAccelerator,
+    # config_file: Path
 ):
     print('SAVING LORA')
-    if accelerator.is_main_process:
-        accelerator.save_lora_fsdp(
-            model,
-            save_directory=args.output_dir,
-            max_shard_size="5GB",
-            safe_serialization=True
-        )
-        tokenizer.save_pretrained(args.output_dir)
+    # if accelerator.is_main_process:
+    output_dir = Path(args.output_dir)
+    if not output_dir.exists():
+        output_dir.mkdir(parents=True, exist_ok=True)
+    accelerator.save_lora_fsdp(
+        model,
+        save_directory=args.output_dir,
+        max_shard_size="5GB",
+        safe_serialization=True
+    )
+    config_file_out = Path(f'{args.output_dir}/config.json')
+
+    model.config.to_json_file(config_file_out)
+    tokenizer.save_pretrained(args.output_dir)
     dist.barrier()

 def save_hf_format_accelerate(
@@ -700,7 +707,7 @@ def save_hf_format_accelerate(
     output_config_file = output_dir / CONFIG_NAME

     if is_lora and accelerator.distributed_type == DistributedType.FSDP:
-        save_hf_format_accelerate_lora(args, model, tokenizer, accelerator)
+        save_hf_format_accelerate_lora(args, model, tokenizer, accelerator )
         return

     if is_lora: