Commit
wip - saving logic
RobotSail committed Oct 14, 2024
1 parent b9c137d commit ff4d5df
Showing 2 changed files with 80 additions and 16 deletions.
71 changes: 64 additions & 7 deletions src/instructlab/training/internal/accelerator.py
@@ -7,7 +7,9 @@
from peft import LoraModel
from torch import distributed as dist
from torch import nn
from torch.cuda import empty_cache
from transformers import PreTrainedModel
import torch

def wraps(module: nn.Module, wrapped_classes: Tuple[Any]) -> bool:
"""Checks if a module or its children are an instance of one of the provided classes.
@@ -51,20 +53,22 @@ def prepare(self, *args, **kwargs):
# Make sure this only happens once per object lifecycle - we call accelerator.prepare
# several times.
num_times_found = 0
using_lora = torch.ByteTensor([False]).cuda()
if self.distributed_type == DistributedType.FSDP and self.is_main_process and not self._cpu_model:
for arg in args:
if isinstance(arg, nn.Module) and wraps(arg, PreTrainedModel) and wraps(arg, LoraModel):
self._is_lora = True
using_lora[0] = True

Check failure on line 60 in src/instructlab/training/internal/accelerator.py (GitHub Actions / pylint): E1137: 'using_lora' does not support item assignment (unsupported-assignment-operation)
num_times_found += 1
# from IPython import embed; embed()
# cpu model setter logic will handle validation - but this may be a stupid idea and
# we should instead handle it here
# self.cpu_model = arg


self.cpu_model = arg
break
print(f'number of times we found a lora pretrained arg: {num_times_found}')
dist.breakpoint()
dist.barrier()
dist.all_reduce(using_lora, op=dist.ReduceOp.MAX)
if using_lora[0]:

Check failure on line 70 in src/instructlab/training/internal/accelerator.py (GitHub Actions / pylint): E1136: Value 'using_lora' is unsubscriptable (unsubscriptable-object)
self._is_lora = True
return super().prepare(*args, **kwargs)
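
(Editorial note: `using_lora` is a one-element uint8 CUDA tensor so that a decision reached only on the main process can be shared with every rank via `dist.all_reduce` with `ReduceOp.MAX`, which acts as a logical OR across the process group. The E1137/E1136 pylint failures above appear to be false positives, since pylint cannot infer that `torch.ByteTensor([...]).cuda()` returns an indexable tensor. A minimal standalone sketch of the same flag-broadcast pattern, written with hypothetical names and `torch.tensor`, might look like this.)

# Editorial sketch, not part of the commit: broadcast a per-rank boolean to all
# ranks. MAX over {0, 1} behaves like a logical OR across the process group.
import torch
import torch.distributed as dist

def any_rank_true(local_flag: bool, device: torch.device) -> bool:
    flag = torch.tensor([1 if local_flag else 0], dtype=torch.uint8, device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())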

@property
@@ -91,4 +95,57 @@ def cpu_model(self, model: nn.Module) -> nn.Module | None:
# As long as we correctly prepare the model through Accelerate, we should not hit this.
raise RuntimeError('Copying a model from the GPU to the CPU is not supported.')

self._cpu_model = deepcopy(model)

def save_lora_fsdp(self, model: nn.Module, *args, **kwargs) -> None:
"""Extension of the `accelerate.Accelerator.save_model` method.
This provides the ability to save a model in SafeTensors format when training a LoRA with FSDP.
Args:
model (nn.Module): The accelerator-wrapped model to save.
"""


if self.distributed_type != DistributedType.FSDP:
raise RuntimeError('`__SuperAccelerator.save_fsdp_lora` was called when FSDP was not being used.')
if not self._is_lora:
raise RuntimeError('`__SuperAccelerator.save_fsdp_lora` was called but was not configured to use LoRA')

print('GETTING OLD MODEL STATE DICT')
model_state = self.get_state_dict(model, unwrap=True)
print('COPYING CPU MODEL')
if self.is_main_process:
tmp_model: LoraModel = deepcopy(self.cpu_model)
print('LOADING STATE DICT INTO TEMP MODEL')
tmp_model.load_state_dict(model_state)
print('MERGING & UNLOADING TEMP MODEL')
tmp_model.merge_and_unload(True)
print('GETTING STATE DICT FROM TEMP MODEL')
model_state = tmp_model.state_dict()

old_get_state_dict = self.get_state_dict
def _custom_get_state_dict(ref: Accelerator, *args, **kwargs):
"""
Custom function to trick `accelerate.Accelerator` to get a state dict that will work
when training with LoRA & FSDP.
"""
print('RETURNING TEMP MODEL STATE DICT INTO ACCELERATORS SAVE MODEL')
return model_state

print('OVERWRITING get_state_dict WITH CUSTOM FN')
self.get_state_dict = _custom_get_state_dict
print('CALLING ACCELERATOR SAVE_MODEL')
self.save_model(model, *args, **kwargs)
print('RETURNING get_state_dict FUNCTION TO OLD')
self.get_state_dict = old_get_state_dict

print('DELETING TMP MODEL')
del tmp_model
empty_cache()



print('SAVED SUCCESSFULLY')
# from IPython import embed; embed()
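
(Editorial note on the new `save_lora_fsdp` method above: it gathers the full state dict from the FSDP-wrapped model, loads it into the CPU copy of the LoRA model, folds the adapters into the base weights with `merge_and_unload`, and then temporarily swaps out `Accelerator.get_state_dict` so that `accelerator.save_model` serializes the merged weights in SafeTensors format. A minimal sketch of that temporary override, with hypothetical names and no claim to match the final implementation, could look like this.)

# Editorial sketch, not part of the commit: hand save_model() a precomputed
# (already merged) state dict by swapping get_state_dict for the duration of
# the save, then restoring the original method.
from contextlib import contextmanager
from accelerate import Accelerator

@contextmanager
def merged_state_dict_for_saving(accelerator: Accelerator, merged_state: dict):
    original_get_state_dict = accelerator.get_state_dict
    # save_model() asks the accelerator for the model's state dict; return the
    # merged weights instead of re-gathering them from the FSDP model.
    accelerator.get_state_dict = lambda model, unwrap=True: merged_state
    try:
        yield
    finally:
        accelerator.get_state_dict = original_get_state_dict

# Hypothetical usage, assuming merged_state came from a merge_and_unload()'d CPU copy:
# with merged_state_dict_for_saving(accelerator, merged_state):
#     accelerator.save_model(model, save_directory=output_dir, safe_serialization=True)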

25 changes: 16 additions & 9 deletions src/instructlab/training/utils.py
@@ -662,16 +662,23 @@ def save_hf_format_accelerate_lora(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
accelerator: __SuperAccelerator,
# config_file: Path
):
print('SAVING LORA')
if accelerator.is_main_process:
accelerator.save_lora_fsdp(
model,
save_directory=args.output_dir,
max_shard_size="5GB",
safe_serialization=True
)
tokenizer.save_pretrained(args.output_dir)
# if accelerator.is_main_process:
output_dir = Path(args.output_dir)
if not output_dir.exists():
output_dir.mkdir(parents=True, exist_ok=True)
accelerator.save_lora_fsdp(
model,
save_directory=args.output_dir,
max_shard_size="5GB",
safe_serialization=True
)
config_file_out = Path(f'{args.output_dir}/config.json')

Check warning on line 679 in src/instructlab/training/utils.py (GitHub Actions / pylint): C0303: Trailing whitespace (trailing-whitespace)
model.config.to_json_file(config_file_out)
tokenizer.save_pretrained(args.output_dir)
dist.barrier()
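
(Editorial note: with the `is_main_process` guard commented out, every rank now enters `accelerator.save_lora_fsdp`, which is generally required because gathering an FSDP state dict is a collective operation, and the trailing `dist.barrier()` keeps ranks in step until the checkpoint is on disk. As a hedged illustration of what the resulting output directory is expected to contain, a post-training sanity check might look like the sketch below; file names are assumed from the usual Hugging Face conventions used here.)

# Editorial sketch, not part of the commit: verify the artifacts this helper is
# expected to write to args.output_dir.
from pathlib import Path

def checkpoint_looks_complete(output_dir: str) -> bool:
    out = Path(output_dir)
    has_weights = any(out.glob("*.safetensors"))               # merged model shards from save_model
    has_config = (out / "config.json").exists()                # written via model.config.to_json_file
    has_tokenizer = (out / "tokenizer_config.json").exists()   # from tokenizer.save_pretrained
    return has_weights and has_config and has_tokenizer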

def save_hf_format_accelerate(
@@ -700,7 +707,7 @@ def save_hf_format_accelerate(
output_config_file = output_dir / CONFIG_NAME

if is_lora and accelerator.distributed_type == DistributedType.FSDP:
save_hf_format_accelerate_lora(args, model, tokenizer, accelerator)
save_hf_format_accelerate_lora(args, model, tokenizer, accelerator )
return

if is_lora:
