[WIP] LoRA + FSDP 2 #269

Closed · wants to merge 4 commits
15 changes: 15 additions & 0 deletions src/instructlab/training/internal/__init__.py
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0

"""
This module is intended to house INTERNAL functions & symbols. Everything in this module is
considered private API and is thus not supported.

The signatures here are also not covered by the semver versioning scheme and
may change at any time.

By using this module, you assume full responsibility for maintenance.
"""

__all__ = ("__SuperAccelerator")

Check failure on line 13 in src/instructlab/training/internal/__init__.py (GitHub Actions / pylint): E0605: Invalid format for __all__, must be tuple or list (invalid-all-format)

from .accelerator import __SuperAccelerator
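Note: the E0605 failure above fires because ("__SuperAccelerator") is just a parenthesized string, not a tuple. A minimal fix, shown only for illustration and not applied in this diff, is a one-element tuple:

    # The trailing comma makes this a tuple of one name rather than a bare string.
    __all__ = ("__SuperAccelerator",)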
151 changes: 151 additions & 0 deletions src/instructlab/training/internal/accelerator.py
@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Tuple, Any

from accelerate import Accelerator, DistributedType
from peft import LoraModel
from torch import distributed as dist
from torch import nn
from torch.cuda import empty_cache
from transformers import PreTrainedModel
import torch

def wraps(module: nn.Module, wrapped_classes: Tuple[Any]) -> bool:
"""Checks if a module or its children are an instance of one of the provided classes.

Args:
module (nn.Module): A PyTorch module.
wrapped_classes (Tuple): A tuple of potential classes the module could be an instance of.

Returns:
bool: True if the module or any of its children are instances of one of the provided classes, False otherwise.
"""
if isinstance(module, wrapped_classes):
return True

for m in module.children():
if wraps(m, wrapped_classes):
return True

return False
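As an illustration (not part of this diff) of what `wraps` reports: it answers "is this module, or anything nested inside it, an instance of one of these classes?"

    from torch import nn
    from instructlab.training.internal.accelerator import wraps

    container = nn.Sequential(nn.Linear(8, 8), nn.ModuleList([nn.ReLU()]))
    print(wraps(container, (nn.ReLU,)))    # True: an nn.ReLU sits in a nested child
    print(wraps(container, (nn.Conv2d,)))  # False: nothing in the tree matches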



class __SuperAccelerator(Accelerator):
"""
Custom InstructLab Accelerator class that extends the `accelerate.Accelerator` object.
We extend it to modify functionality in the stock Accelerator that prevents us from
properly saving LoRA models when using FSDP as the distributed backend.

Warning: This is NOT a public API and is not intended to be supported beyond its
internal usage in this library. Use at your own discretion.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._cpu_model: PreTrainedModel = None
self._is_lora = False

def prepare(self, *args, **kwargs):
# Extract the model so we can keep a copy of it on the CPU.
# Make sure this only happens once per object lifetime - accelerator.prepare is called
# several times.
num_times_found = 0
using_lora = torch.ByteTensor([False]).cuda()
if self.distributed_type == DistributedType.FSDP and self.is_main_process and not self._cpu_model:
for arg in args:
if isinstance(arg, nn.Module) and wraps(arg, PreTrainedModel) and wraps(arg, LoraModel):
using_lora[0] = True

Check failure on line 60 in src/instructlab/training/internal/accelerator.py (GitHub Actions / pylint): E1137: 'using_lora' does not support item assignment (unsupported-assignment-operation)
num_times_found += 1
# The cpu_model setter handles validation - though it may be better to
# handle it here instead.
self.cpu_model = arg
break

print(f'number of times we found a lora pretrained arg: {num_times_found}')
dist.barrier()
dist.all_reduce(using_lora, op=dist.ReduceOp.MAX)
if using_lora[0]:

Check failure on line 70 in src/instructlab/training/internal/accelerator.py (GitHub Actions / pylint): E1136: Value 'using_lora' is unsubscriptable (unsubscriptable-object)
self._is_lora = True
return super().prepare(*args, **kwargs)

@property
def cpu_model(self) -> nn.Module:
if self.is_main_process:
return self._cpu_model
return None

@cpu_model.setter
def cpu_model(self, model: nn.Module) -> nn.Module | None:
"""
Given a model **BEFORE** it is sent to FSDP, we copy it and keep it on the CPU.
The model is only stored on the main process; on non-main processes this is a no-op
and the `cpu_model` getter returns None.
"""
if not self.is_main_process:
# only one process in the group should ever store the model
return

# ensure the model is not on the GPU yet
if any(p.is_cuda for p in model.parameters()):
# while it is POSSIBLE to copy a model from the GPU to the CPU, we should avoid doing this
# due to potential memory constraints.
#
# As long as we correctly prepare the model through Accelerate, we should not hit this.
raise RuntimeError('Copying a model from the GPU to the CPU is not supported.')

self._cpu_model = deepcopy(model)

def save_lora_fsdp(self, model: nn.Module, *args, **kwargs) -> None:
"""Extension of the `accelerate.Accelerator.save_model` method.

This provides the ability to save a model in SafeTensors format when training a LoRA with FSDP.

Args:
model (nn.Module): The accelerator-wrapped model to save.
"""


if self.distributed_type != DistributedType.FSDP:
raise RuntimeError('`__SuperAccelerator.save_lora_fsdp` was called when FSDP was not being used.')
if not self._is_lora:
raise RuntimeError('`__SuperAccelerator.save_lora_fsdp` was called but the accelerator was not configured to use LoRA.')

print('GETTING OLD MODEL STATE DICT')
model_state = self.get_state_dict(model, unwrap=True)
print('COPYING CPU MODEL')
if self.is_main_process:
tmp_model: LoraModel = deepcopy(self.cpu_model)
print('LOADING STATE DICT INTO TEMP MODEL')
tmp_model.load_state_dict(model_state)
print('MERGING & UNLOADING TEMP MODEL')
tmp_model.merge_and_unload(True)
print('GETTING STATE DICT FROM TEMP MODEL')
model_state = tmp_model.state_dict()

old_get_state_dict = self.get_state_dict
def _custom_get_state_dict(ref: Accelerator, *args, **kwargs):
"""
Custom function to trick `accelerate.Accelerator` to get a state dict that will work
when training with LoRA & FSDP.
"""
print('RETURNING TEMP MODEL STATE DICT INTO ACCELERATORS SAVE MODEL')
return model_state

print('OVERWRITING get_state_dict WITH CUSTOM FN')
self.get_state_dict = _custom_get_state_dict
print('CALLING ACCELERATOR SAVE_MODEL')
self.save_model(model, *args, **kwargs)
print('RETURNING get_state_dict FUNCTION TO OLD')
self.get_state_dict = old_get_state_dict

print('DELETING TMP MODEL')
del tmp_model
empty_cache()



print('SAVED SUCCESSFULLY')
# from IPython import embed; embed()
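A rough end-to-end sketch of how this accelerator is meant to be driven (the function and variable names below are stand-ins, not part of this PR): prepare() caches a CPU deepcopy of the LoRA model on the main process, and save_lora_fsdp() later loads the trained weights into that copy, merges the adapter, and routes the merged state dict through save_model.

    from instructlab.training.internal import __SuperAccelerator

    def train_and_save(model, fsdp_plugin, output_dir: str) -> None:
        accelerator = __SuperAccelerator(fsdp_plugin=fsdp_plugin, mixed_precision="bf16")
        model = accelerator.prepare(model)  # the main process stores a CPU copy of the LoRA model here
        # ... run the training loop on `model` ...
        accelerator.save_lora_fsdp(model, output_dir, safe_serialization=True)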

35 changes: 27 additions & 8 deletions src/instructlab/training/main_ds.py
@@ -18,7 +18,7 @@
# pylint: disable=no-name-in-module
from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
from tqdm import tqdm
from transformers import AutoModelForCausalLM, get_scheduler
from transformers import AutoModelForCausalLM, get_scheduler, LlamaForCausalLM, PreTrainedModel
import torch
import torch.distributed

@@ -211,9 +211,8 @@
task_type="CAUSAL_LM",
target_modules=args.lora_target_modules,
)
model = prepare_peft_model(
model, peft_config, gradient_checkpointing=not args.is_granite
)
from peft import LoraModel
model = LoraModel(model, peft_config, "default")

elif not args.is_granite:
model.gradient_checkpointing_enable()
@@ -238,6 +237,15 @@
accelerator = setup_accelerator(args, model, grad_accum)
if args.distributed_training_framework == DistributedBackend.FSDP.value:
model = accelerator.prepare(model)
# print(model)
# print(f"model is instance of Pretrainemodel: {isinstance(model, PreTrainedModel)}")

# if accelerator.is_main_process:
# from IPython import embed; embed()

torch.distributed.barrier()


optimizer = setup_optimizer(args, model)

lr_scheduler = get_scheduler(
@@ -526,9 +534,11 @@
tokenizer = setup_tokenizer(args.model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE)
# device = torch.device("cuda", args.local_rank)

local_rank = int(os.environ['LOCAL_RANK'])

#### distributed init #####
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
args.local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
args.local_rank = local_rank
torch.distributed.init_process_group("nccl")
args.global_rank = torch.distributed.get_rank()
tensor = torch.ByteTensor([False]).cuda()
@@ -541,6 +551,8 @@
mock_len=args.mock_len,
)

# Check this across all processes - in theory it shouldn't matter, but we sync anyway so every rank takes the same path. 😎
multipack_did_fail = torch.ByteTensor([False]).cuda()
try:
packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum(
num_gpus=torch.distributed.get_world_size(),
@@ -552,15 +564,22 @@
seed=args.seed,
)
args.sampler = "multipack"
except RuntimeError as e:
except (RuntimeError, ZeroDivisionError) as e:
if os.environ["LOCAL_RANK"] == "0":
print(f"\033[38;5;120m{e}\033[0m")

# fallback to grad accum = 1
# NOTE: packing max batch len will not be used
packing_max_batch_len = None
grad_accum = 1
args.sampler = "distributed"
multipack_did_fail[0] = 1

Check failure on line 575 in src/instructlab/training/main_ds.py (GitHub Actions / pylint): E1137: 'multipack_did_fail' does not support item assignment (unsupported-assignment-operation)
finally:
torch.distributed.all_reduce(multipack_did_fail, op=torch.distributed.ReduceOp.MAX)
if multipack_did_fail[0]:

Check failure on line 578 in src/instructlab/training/main_ds.py (GitHub Actions / pylint): E1136: Value 'multipack_did_fail' is unsubscriptable (unsubscriptable-object)
print('multipack was found to fail on at least one process, falling back to naive sampling')
args.sampler = "distributed"



args.samples_per_gpu = (
args.effective_batch_size // grad_accum // torch.distributed.get_world_size()
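The multipack fallback above relies on a small cross-rank signalling pattern: each rank records its local failure in a ByteTensor, and an all_reduce with ReduceOp.MAX propagates the flag so every rank agrees on whether to fall back to the distributed sampler. A minimal standalone sketch of the idea (the helper name is ours, and it assumes the process group is already initialized):

    import torch
    import torch.distributed as dist

    def any_rank_failed(local_failure: bool) -> bool:
        """Return True on every rank if at least one rank reported a failure."""
        flag = torch.tensor([1 if local_failure else 0], dtype=torch.uint8, device="cuda")
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)  # MAX over {0, 1} acts as a logical OR
        return bool(flag.item())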
34 changes: 27 additions & 7 deletions src/instructlab/training/setup_accelerator.py
@@ -1,19 +1,30 @@
# Standard
from copy import deepcopy

Check warning on line 2 in src/instructlab/training/setup_accelerator.py (GitHub Actions / pylint): W0611: Unused deepcopy imported from copy (unused-import)
from functools import partial
import os

Check warning on line 4 in src/instructlab/training/setup_accelerator.py (GitHub Actions / pylint): W0611: Unused import os (unused-import)
from typing import Callable, Tuple, Any, Union

Check warning on line 5 in src/instructlab/training/setup_accelerator.py (GitHub Actions / pylint): W0611: Unused Callable imported from typing (unused-import)
Check warning on line 5 in src/instructlab/training/setup_accelerator.py (GitHub Actions / pylint): W0611: Unused Tuple imported from typing (unused-import)
Check warning on line 5 in src/instructlab/training/setup_accelerator.py (GitHub Actions / pylint): W0611: Unused Any imported from typing (unused-import)
Check warning on line 5 in src/instructlab/training/setup_accelerator.py (GitHub Actions / pylint): W0611: Unused Union imported from typing (unused-import)

# Third Party
from accelerate import Accelerator
from torch.distributed.fsdp import ( # FullyShardedDataParallel as FSDP,
from accelerate import Accelerator, DistributedType

Check warning on line 8 in src/instructlab/training/setup_accelerator.py (GitHub Actions / pylint): W0611: Unused Accelerator imported from accelerate (unused-import)
Check warning on line 8 in src/instructlab/training/setup_accelerator.py (GitHub Actions / pylint): W0611: Unused DistributedType imported from accelerate (unused-import)
from peft.tuners.lora.model import LoraModel
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.fsdp import (
BackwardPrefetch,
MixedPrecision,
ShardingStrategy,
)
from torch import nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import torch
from torch import distributed as dist
from transformers import PreTrainedModel

# First Party
from instructlab.training.config import DeepSpeedOptions
from instructlab.training.utils import get_module_class_from_name, patch_target_module
from instructlab.training.internal import __SuperAccelerator




def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions):
@@ -58,22 +69,30 @@

block_name = model._no_split_modules[0]

fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=partial(
wrap_policy = None
if args.lora_r > 0:
from peft.utils.other import fsdp_auto_wrap_policy

Check warning on line 74 in src/instructlab/training/setup_accelerator.py (GitHub Actions / pylint): W0404: Reimport 'fsdp_auto_wrap_policy' (imported line 10) (reimported)
wrap_policy = fsdp_auto_wrap_policy(model)
else:
wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
get_module_class_from_name(model, block_name),
},
),
)

fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrap_policy,
limit_all_gathers=True,
mixed_precision_policy=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
backward_prefetch=BackwardPrefetch.BACKWARD_POST,
sharding_strategy=ShardingStrategy[args.fsdp_sharding_strategy],
cpu_offload=CPUOffload(args.cpu_offload_params_fsdp),
use_orig_params=False,
)
return fsdp_plugin

@@ -106,12 +125,13 @@
elif args.distributed_training_framework == "fsdp":
accel_args = {
"fsdp_plugin": get_fsdp_config(args, model),
'mixed_precision': "bf16",
}
else:
raise ValueError(
f"Unknown sharding framework: {args.distributed_training_framework}"
)
accelerator = Accelerator(
accelerator = __SuperAccelerator(
**accel_args,
)
accelerator.even_batches = False