From 65be668f1073f88cf25f28766b05f3168e5a2987 Mon Sep 17 00:00:00 2001
From: Oleg S <97077423+RobotSail@users.noreply.github.com>
Date: Fri, 11 Oct 2024 17:41:56 -0400
Subject: [PATCH 1/4] properly initialize lora model

---
 src/instructlab/training/main_ds.py           |  5 ++---
 src/instructlab/training/setup_accelerator.py | 16 ++++++++++++----
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index c5cdb2ba..1cd346cd 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -211,9 +211,8 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
             task_type="CAUSAL_LM",
             target_modules=args.lora_target_modules,
         )
-        model = prepare_peft_model(
-            model, peft_config, gradient_checkpointing=not args.is_granite
-        )
+        from peft import LoraModel
+        model = LoraModel(model, peft_config, "default")
 
     elif not args.is_granite:
         model.gradient_checkpointing_enable()
diff --git a/src/instructlab/training/setup_accelerator.py b/src/instructlab/training/setup_accelerator.py
index 33972b59..1eb5cc04 100644
--- a/src/instructlab/training/setup_accelerator.py
+++ b/src/instructlab/training/setup_accelerator.py
@@ -58,22 +58,30 @@ def get_fsdp_config(args, model):
     block_name = model._no_split_modules[0]
 
-    fsdp_plugin = FullyShardedDataParallelPlugin(
-        auto_wrap_policy=partial(
+    wrap_policy = None
+    if args.lora_r > 0:
+        from peft.utils.other import fsdp_auto_wrap_policy
+        wrap_policy = fsdp_auto_wrap_policy(model)
+    else:
+        wrap_policy = partial(
             transformer_auto_wrap_policy,
             transformer_layer_cls={
                 get_module_class_from_name(model, block_name),
             },
-        ),
+        )
+
+    fsdp_plugin = FullyShardedDataParallelPlugin(
+        auto_wrap_policy=wrap_policy,
         limit_all_gathers=True,
         mixed_precision_policy=MixedPrecision(
             param_dtype=torch.bfloat16,
             reduce_dtype=torch.bfloat16,
             buffer_dtype=torch.bfloat16,
         ),
-        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+        backward_prefetch=BackwardPrefetch.BACKWARD_POST,
         sharding_strategy=ShardingStrategy[args.fsdp_sharding_strategy],
         cpu_offload=CPUOffload(args.cpu_offload_params_fsdp),
+        use_orig_params=False,
     )
 
     return fsdp_plugin

From 92392da52a816bf7a68bab3e9917bcbe8ef62321 Mon Sep 17 00:00:00 2001
From: Oleg S <97077423+RobotSail@users.noreply.github.com>
Date: Fri, 11 Oct 2024 17:34:33 -0400
Subject: [PATCH 2/4] wip commit

---
 src/instructlab/training/internal/__init__.py |  15 +++
 .../training/internal/accelerator.py          |  94 +++++++++++++++++++
 src/instructlab/training/main_ds.py           |  30 +++++-
 src/instructlab/training/setup_accelerator.py |  18 +++-
 src/instructlab/training/utils.py             |  40 +++++---
 5 files changed, 178 insertions(+), 19 deletions(-)
 create mode 100644 src/instructlab/training/internal/__init__.py
 create mode 100644 src/instructlab/training/internal/accelerator.py

diff --git a/src/instructlab/training/internal/__init__.py b/src/instructlab/training/internal/__init__.py
new file mode 100644
index 00000000..82dddd10
--- /dev/null
+++ b/src/instructlab/training/internal/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+This module is intended to house INTERNAL functions & symbols. Things in this module are considered
+private API and are thus not supported.
+
+The signatures here are also not considered in the versioning scheme under semver, and
+may change at any time.
+
+By using this module, you assume full responsibility for its maintenance.
+""" + +__all__ = ("__SuperAccelerator") + +from .accelerator import __SuperAccelerator \ No newline at end of file diff --git a/src/instructlab/training/internal/accelerator.py b/src/instructlab/training/internal/accelerator.py new file mode 100644 index 00000000..f06a414d --- /dev/null +++ b/src/instructlab/training/internal/accelerator.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 + +from copy import deepcopy +from typing import Tuple, Any + +from accelerate import Accelerator, DistributedType +from peft import LoraModel +from torch import distributed as dist +from torch import nn +from transformers import PreTrainedModel + +def wraps(module: nn.Module, wrapped_classes: Tuple[Any]) -> bool: + """Checks if a module or its children are an instance of one of the provided classes. + + Args: + module (nn.Module): A PyTorch module. + wrapped_classes(Tuple): A tuple of potential classes the module could be. + + Returns: + bool: True if the module or any of its children are instances of `transformers.PreTrainedModel`, False otherwise. + """ + if isinstance(module, wrapped_classes): + return True + + for m in module.children(): + if wraps(m, wrapped_classes): + return True + + return False + + + +class __SuperAccelerator(Accelerator): + """ + Custom InstructLab Accelerator class that extends the `accelerate.Accelerator` object. + We extend this class to modify some functionality embedded in the existing Accelerator + which prevents us from being able to properly save LoRA models when using FSDP as the + distributed backend. + + Warning: This is NOT a public API and is not intended to be supported beyond its + internal usage in this library. Use at your own discretion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._cpu_model: PreTrainedModel = None + self._is_lora = False + + def prepare(self, *args, **kwargs): + # Extract out the model to make a copy on the cpu. + # Make sure this only happens once per object lifecycle - we call accelerator.prepare + # several times. + num_times_found = 0 + if self.distributed_type == DistributedType.FSDP and not self.cpu_model and self.is_main_process: + for arg in args: + if isinstance(arg, nn.Module) and wraps(arg, PreTrainedModel) and wraps(arg, LoraModel): + self._is_lora = True + num_times_found += 1 + # from IPython import embed; embed() + # cpu model setter logic will handle validation - but this may be a stupid idea and + # we should instead handle it here + # self.cpu_model = arg + + + print(f'number of times we found a lora pretrained arg: {num_times_found}') + dist.breakpoint() + dist.barrier() + return super().prepare(*args, **kwargs) + + @property + def cpu_model(self) -> nn.Module: + if self.is_main_process: + return self._cpu_model + return None + + @cpu_model.setter + def cpu_model(self, model: nn.Module) -> nn.Module | None: + """ + Given a model **BEFORE** it is sent to FSDP, we copy it and keep it on the CPU. + The model is only stored for the main process, so calling on a non-main will return None. + """ + if not self.is_main_process: + # only one process in the group should ever store the model + return + + # ensure the model is not on the GPU yet + if any(p.is_cuda for p in model.parameters()): + # while it is POSSIBLE to copy a model from the GPU to the CPU, we should avoid doing this + # due to potential memory constraints. + # + # As long as we correctly prepare the model through Accelerate, we should not hit this. 
+            raise RuntimeError('Copying a model from the GPU to the CPU is not supported.')
+
+        self._cpu_model = deepcopy(model)
\ No newline at end of file
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 1cd346cd..a45154ea 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -18,7 +18,7 @@
 # pylint: disable=no-name-in-module
 from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, get_scheduler
+from transformers import AutoModelForCausalLM, get_scheduler, LlamaForCausalLM, PreTrainedModel
 import torch
 import torch.distributed
 
@@ -237,6 +237,15 @@ def make_inputs_require_grad(module, input, output):
     accelerator = setup_accelerator(args, model, grad_accum)
     if args.distributed_training_framework == DistributedBackend.FSDP.value:
         model = accelerator.prepare(model)
+    # print(model)
+    # print(f"model is instance of PreTrainedModel: {isinstance(model, PreTrainedModel)}")
+
+    # if accelerator.is_main_process:
+    #     from IPython import embed; embed()
+
+    torch.distributed.barrier()
+
+
     optimizer = setup_optimizer(args, model)
 
     lr_scheduler = get_scheduler(
@@ -525,9 +534,11 @@ def main(args):
    tokenizer = setup_tokenizer(args.model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE)
     # device = torch.device("cuda", args.local_rank)
 
+    local_rank = int(os.environ['LOCAL_RANK'])
+
     #### distributed init #####
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    args.local_rank = int(os.environ["LOCAL_RANK"])
+    torch.cuda.set_device(local_rank)
+    args.local_rank = local_rank
     torch.distributed.init_process_group("nccl")
     args.global_rank = torch.distributed.get_rank()
     tensor = torch.ByteTensor([False]).cuda()
@@ -540,6 +551,8 @@ def main(args):
         mock_len=args.mock_len,
     )
 
+    # check this across all processes - in theory this shouldn't matter, but doing it anyway cause it's cool 😎
+    multipack_did_fail = torch.ByteTensor([False]).cuda()
     try:
         packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum(
             num_gpus=torch.distributed.get_world_size(),
@@ -551,7 +564,7 @@ def main(args):
             seed=args.seed,
         )
         args.sampler = "multipack"
-    except RuntimeError as e:
+    except (RuntimeError, ZeroDivisionError) as e:
         if os.environ["LOCAL_RANK"] == "0":
             print(f"\033[38;5;120m{e}\033[0m")
 
@@ -559,7 +572,14 @@ def main(args):
         # NOTE: packing max batch len will not be used
         packing_max_batch_len = None
         grad_accum = 1
-        args.sampler = "distributed"
+        multipack_did_fail[0] = 1
+    finally:
+        torch.distributed.all_reduce(multipack_did_fail, op=torch.distributed.ReduceOp.MAX)
+        if multipack_did_fail[0]:
+            print('multipack was found to fail on at least one process, falling back to naive sampling')
+            args.sampler = "distributed"
+
+
 
     args.samples_per_gpu = (
         args.effective_batch_size // grad_accum // torch.distributed.get_world_size()
diff --git a/src/instructlab/training/setup_accelerator.py b/src/instructlab/training/setup_accelerator.py
index 1eb5cc04..918dbef1 100644
--- a/src/instructlab/training/setup_accelerator.py
+++ b/src/instructlab/training/setup_accelerator.py
@@ -1,19 +1,30 @@
 # Standard
+from copy import deepcopy
 from functools import partial
+import os
+from typing import Callable, Tuple, Any, Union
 
 # Third Party
-from accelerate import Accelerator
-from torch.distributed.fsdp import (  # FullyShardedDataParallel as FSDP,
+from accelerate import Accelerator, DistributedType
+from peft.tuners.lora.model import LoraModel
+from peft.utils.other import fsdp_auto_wrap_policy
+from torch.distributed.fsdp import (
     BackwardPrefetch,
     MixedPrecision,
     ShardingStrategy,
 )
+from torch import nn
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 import torch
+from torch import distributed as dist
+from transformers import PreTrainedModel
 
 # First Party
 from instructlab.training.config import DeepSpeedOptions
 from instructlab.training.utils import get_module_class_from_name, patch_target_module
+from instructlab.training.internal import __SuperAccelerator
+
+
 
 
 def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions):
@@ -114,12 +125,13 @@ def setup_accelerator(args, model, grad_accum):
     elif args.distributed_training_framework == "fsdp":
         accel_args = {
             "fsdp_plugin": get_fsdp_config(args, model),
+            "mixed_precision": "bf16",
         }
     else:
         raise ValueError(
             f"Unknown sharding framework: {args.distributed_training_framework}"
         )
-    accelerator = Accelerator(
+    accelerator = __SuperAccelerator(
         **accel_args,
     )
     accelerator.even_batches = False
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 6d79d897..9952b0a0 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -21,7 +21,7 @@
 
 # Third Party
 # pylint: disable=no-name-in-module
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from instructlab.dolomite.hf_models import (
     GPTDolomiteConfig,
     export_to_huggingface,
@@ -40,6 +40,8 @@
 import torch
 import torch.nn.functional as F
 
+from instructlab.training.internal import __SuperAccelerator
+
 
 def retrieve_chat_template(chat_tmpl_path):
     try:
@@ -655,13 +657,29 @@ def skip_precheck_loops():
         accelerator.get_state_dict = old_get_state
 
 
+def save_hf_format_accelerate_lora(
+    args,
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    accelerator: __SuperAccelerator,
+):
+    print('SAVING LORA')
+    if accelerator.is_main_process:
+        accelerator.save_lora_fsdp(
+            model,
+            save_directory=args.output_dir,
+            max_shard_size="5GB",
+            safe_serialization=True
+        )
+        tokenizer.save_pretrained(args.output_dir)
+    dist.barrier()
 
 def save_hf_format_accelerate(
     args,
-    model,
-    tokenizer,
-    accelerator: Accelerator,
-    samples_seen,
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    accelerator: __SuperAccelerator,
+    samples_seen: int,
     convert_granite=True,
     is_lora=False,
 ):
@@ -681,12 +699,13 @@ def save_hf_format_accelerate(
     CONFIG_NAME = "config.json"
     output_config_file = output_dir / CONFIG_NAME
 
-    get_state_dict_unpatched = accelerator.get_state_dict
-
-    def _get_state_dict_patched(model, unwrap=False):
-        return get_state_dict_unpatched(model, unwrap=unwrap)
+    if is_lora and accelerator.distributed_type == DistributedType.FSDP:
+        save_hf_format_accelerate_lora(args, model, tokenizer, accelerator)
+        return
 
-    accelerator.get_state_dict = _get_state_dict_patched
+    if is_lora:
+        model.module.merge_adapter()
+        model_state = model.module.state_dict()
 
     if accelerator.is_main_process:
         if is_lora:
@@ -735,7 +754,6 @@ def _get_state_dict_patched(model, unwrap=False):
 
     log_rank_0(f"saving took {time.time() - start} seconds")
     dist.barrier()
-    accelerator.get_state_dict = get_state_dict_unpatched
 
 
 # this is native deepspeed saving with optimizer, scheduler

From b9c137d0cd5efbf7dce916acdb7992b4eca680e4 Mon Sep 17 00:00:00 2001
From: Oleg S <97077423+RobotSail@users.noreply.github.com>
Date: Fri, 11 Oct 2024 17:50:13 -0400
Subject: [PATCH 3/4] wip fix

---
 src/instructlab/training/internal/accelerator.py | 2 +-
 src/instructlab/training/utils.py                | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/instructlab/training/internal/accelerator.py b/src/instructlab/training/internal/accelerator.py
index f06a414d..949c0c17 100644
--- a/src/instructlab/training/internal/accelerator.py
+++ b/src/instructlab/training/internal/accelerator.py
@@ -51,7 +51,7 @@ def prepare(self, *args, **kwargs):
         # Make sure this only happens once per object lifecycle - we call accelerator.prepare
         # several times.
         num_times_found = 0
-        if self.distributed_type == DistributedType.FSDP and not self.cpu_model and self.is_main_process:
+        if self.distributed_type == DistributedType.FSDP and self.is_main_process and not self._cpu_model:
             for arg in args:
                 if isinstance(arg, nn.Module) and wraps(arg, PreTrainedModel) and wraps(arg, LoraModel):
                     self._is_lora = True
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 9952b0a0..08139366 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -35,7 +35,7 @@
     apply_activation_checkpointing,
     checkpoint_wrapper,
 )
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, PreTrainedTokenizer
 import numpy as np
 import torch
 import torch.nn.functional as F

From ff4d5dfe46e9a2ddfa83a124183ae1d30f6da2af Mon Sep 17 00:00:00 2001
From: Oleg S <97077423+RobotSail@users.noreply.github.com>
Date: Mon, 14 Oct 2024 12:06:00 -0400
Subject: [PATCH 4/4] wip - saving logic

---
 .../training/internal/accelerator.py | 71 +++++++++++++++++--
 src/instructlab/training/utils.py    | 25 ++++---
 2 files changed, 80 insertions(+), 16 deletions(-)

diff --git a/src/instructlab/training/internal/accelerator.py b/src/instructlab/training/internal/accelerator.py
index 949c0c17..211a4e04 100644
--- a/src/instructlab/training/internal/accelerator.py
+++ b/src/instructlab/training/internal/accelerator.py
@@ -7,7 +7,9 @@
 from peft import LoraModel
 from torch import distributed as dist
 from torch import nn
+from torch.cuda import empty_cache
 from transformers import PreTrainedModel
+import torch
 
 def wraps(module: nn.Module, wrapped_classes: Tuple[Any]) -> bool:
     """Checks if a module or its children are an instance of one of the provided classes.
@@ -51,20 +53,22 @@ def prepare(self, *args, **kwargs):
         # Make sure this only happens once per object lifecycle - we call accelerator.prepare
         # several times.
         num_times_found = 0
+        using_lora = torch.ByteTensor([False]).cuda()
         if self.distributed_type == DistributedType.FSDP and self.is_main_process and not self._cpu_model:
             for arg in args:
                 if isinstance(arg, nn.Module) and wraps(arg, PreTrainedModel) and wraps(arg, LoraModel):
-                    self._is_lora = True
+                    using_lora[0] = True
                     num_times_found += 1
-                    # from IPython import embed; embed()
                     # cpu model setter logic will handle validation - but this may be a stupid idea and
                     # we should instead handle it here
-                    # self.cpu_model = arg
-
-
+                    self.cpu_model = arg
+                    break
+
         print(f'number of times we found a lora pretrained arg: {num_times_found}')
-        dist.breakpoint()
         dist.barrier()
+        dist.all_reduce(using_lora, op=dist.ReduceOp.MAX)
+        if using_lora[0]:
+            self._is_lora = True
         return super().prepare(*args, **kwargs)
 
     @property
@@ -91,4 +95,57 @@ def cpu_model(self, model: nn.Module) -> nn.Module | None:
             # As long as we correctly prepare the model through Accelerate, we should not hit this.
             raise RuntimeError('Copying a model from the GPU to the CPU is not supported.')
 
-        self._cpu_model = deepcopy(model)
\ No newline at end of file
+        self._cpu_model = deepcopy(model)
+
+    def save_lora_fsdp(self, model: nn.Module, *args, **kwargs) -> None:
+        """Extension of the `accelerate.Accelerator.save_model` method.
+
+        This provides the ability to save a model in SafeTensors format when training a LoRA with FSDP.
+
+        Args:
+            model (nn.Module): The accelerator-wrapped model to save.
+        """
+
+
+        if self.distributed_type != DistributedType.FSDP:
+            raise RuntimeError('`__SuperAccelerator.save_lora_fsdp` was called when FSDP was not being used.')
+        if not self._is_lora:
+            raise RuntimeError('`__SuperAccelerator.save_lora_fsdp` was called but was not configured to use LoRA')
+
+        print('GETTING OLD MODEL STATE DICT')
+        model_state = self.get_state_dict(model, unwrap=True)
+        print('COPYING CPU MODEL')
+        if self.is_main_process:
+            tmp_model: LoraModel = deepcopy(self.cpu_model)
+            print('LOADING STATE DICT INTO TEMP MODEL')
+            tmp_model.load_state_dict(model_state)
+            print('MERGING & UNLOADING TEMP MODEL')
+            tmp_model.merge_and_unload(True)
+            print('GETTING STATE DICT FROM TEMP MODEL')
+            model_state = tmp_model.state_dict()
+
+            old_get_state_dict = self.get_state_dict
+            def _custom_get_state_dict(ref: Accelerator, *args, **kwargs):
+                """
+                Custom function to trick `accelerate.Accelerator` into using a state dict that will work
+                when training with LoRA & FSDP.
+                """
+                print('RETURNING TEMP MODEL STATE DICT INTO ACCELERATOR SAVE MODEL')
+                return model_state
+
+            print('OVERWRITING get_state_dict WITH CUSTOM FN')
+            self.get_state_dict = _custom_get_state_dict
+            print('CALLING ACCELERATOR SAVE_MODEL')
+            self.save_model(model, *args, **kwargs)
+            print('RETURNING get_state_dict FUNCTION TO OLD')
+            self.get_state_dict = old_get_state_dict
+
+            print('DELETING TMP MODEL')
+            del tmp_model
+            empty_cache()
+
+
+
+        print('SAVED SUCCESSFULLY')
+        # from IPython import embed; embed()
+
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 08139366..17e32261 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -662,16 +662,23 @@ def save_hf_format_accelerate_lora(
     model: PreTrainedModel,
     tokenizer: PreTrainedTokenizer,
     accelerator: __SuperAccelerator,
+    # config_file: Path
 ):
     print('SAVING LORA')
-    if accelerator.is_main_process:
-        accelerator.save_lora_fsdp(
-            model,
-            save_directory=args.output_dir,
-            max_shard_size="5GB",
-            safe_serialization=True
-        )
-        tokenizer.save_pretrained(args.output_dir)
+    # if accelerator.is_main_process:
+    output_dir = Path(args.output_dir)
+    if not output_dir.exists():
+        output_dir.mkdir(parents=True, exist_ok=True)
+    accelerator.save_lora_fsdp(
+        model,
+        save_directory=args.output_dir,
+        max_shard_size="5GB",
+        safe_serialization=True
+    )
+    config_file_out = Path(f'{args.output_dir}/config.json')
+
+    model.config.to_json_file(config_file_out)
+    tokenizer.save_pretrained(args.output_dir)
     dist.barrier()
 
 def save_hf_format_accelerate(
@@ -700,7 +707,7 @@ def save_hf_format_accelerate(
     output_config_file = output_dir / CONFIG_NAME
 
     if is_lora and accelerator.distributed_type == DistributedType.FSDP:
-        save_hf_format_accelerate_lora(args, model, tokenizer, accelerator)
+        save_hf_format_accelerate_lora(args, model, tokenizer, accelerator)
         return
 
     if is_lora:
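
A minimal standalone sketch of the save path the series builds toward, for readers following along. `accelerator.get_state_dict` and `accelerator.save_model` are real `accelerate.Accelerator` APIs and `merge_and_unload` is the real peft API; `save_merged_lora` and `cpu_copy` (standing in for the pre-FSDP deepcopy kept by the `cpu_model` property) are hypothetical names used only for illustration, not part of the patches:

from copy import deepcopy

def save_merged_lora(accelerator, model, cpu_copy, save_directory):
    # 1. Gather the full state dict; with FSDP this must run on every rank.
    model_state = accelerator.get_state_dict(model, unwrap=True)

    if accelerator.is_main_process:
        # 2. Load the gathered weights into the CPU-resident LoraModel copy,
        #    then fold the adapters into the base weights.
        tmp_model = deepcopy(cpu_copy)
        tmp_model.load_state_dict(model_state)
        merged_model = tmp_model.merge_and_unload()
        merged_state = merged_model.state_dict()

        # 3. save_model() asks get_state_dict() for the weights to serialize,
        #    so temporarily point that hook at the merged weights.
        original_get_state_dict = accelerator.get_state_dict
        accelerator.get_state_dict = lambda *a, **kw: merged_state
        try:
            accelerator.save_model(model, save_directory, safe_serialization=True)
        finally:
            # 4. Always restore the real hook, even if saving fails.
            accelerator.get_state_dict = original_get_state_dict

The override of `get_state_dict` is the core trick in PATCH 4: it is what lets the merged LoRA weights, rather than the raw FSDP-sharded ones, reach the SafeTensors writer inside `save_model`.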