Adding Dolomite Support and Bringing HF Padding-Free into Performance Parity #312

Merged (4 commits) on Nov 1, 2024
requirements.txt (2 changes: 1 addition & 1 deletion)

@@ -17,7 +17,7 @@ numba
 numpy>=1.23.5,<2.0.0 ; python_version == '3.10'
 numpy>=1.26.4,<2.0.0 ; python_version != '3.10'
 rich
-instructlab-dolomite>=0.1.1
+instructlab-dolomite>=0.2.0
 trl>=0.9.4
 peft
 pydantic>=2.7.0
src/instructlab/training/main_ds.py (5 changes: 5 additions & 0 deletions)

@@ -4,6 +4,7 @@
 from copy import deepcopy
 from pathlib import Path
 import argparse
+import json
 import math
 import os
 import re
@@ -528,6 +529,10 @@ def main(args):
     tokenizer = setup_tokenizer(args.model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE)
     # device = torch.device("cuda", args.local_rank)

+    with open(Path(args.model_name_or_path) / "config.json") as conf_json:
+        model_conf = json.load(conf_json)
+    args.model_type = model_conf["model_type"]
+
     #### distributed init #####
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
     args.local_rank = int(os.environ["LOCAL_RANK"])
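The added lines read the checkpoint's Hugging Face config so the model family is detected rather than hard-coded to llama. As a standalone illustration of the same lookup (a sketch; detect_model_type is a hypothetical helper, not part of this PR):

import json
from pathlib import Path

def detect_model_type(model_name_or_path: str) -> str:
    """Return the model family recorded in an HF checkpoint directory."""
    # Every Hugging Face checkpoint directory ships a config.json whose
    # "model_type" key names the architecture family, e.g. "llama" or "granite".
    with open(Path(model_name_or_path) / "config.json") as conf_json:
        model_conf = json.load(conf_json)
    return model_conf["model_type"]

# detect_model_type("/path/to/granite/checkpoint")  -> "granite"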
src/instructlab/training/utils.py (39 changes: 21 additions & 18 deletions)

@@ -10,7 +10,6 @@
 from typing import Any, List, Optional
 import importlib
 import inspect
-import json
 import logging
 import os
 import random
@@ -62,17 +61,10 @@ def check_valid_train_args(train_args: TrainingArgs):
             f"Provided path to model does not exist. Please make sure that you've passed a valid model and that it has appropriate permissions: {train_args.model_path}"
         )

-    if train_args.use_dolomite:
-        with open(Path(train_args.model_path) / "config.json") as conf_json:
-            model_conf = json.load(conf_json)
-        if model_conf["model_type"] == "granite":
-            raise RuntimeError(
-                "Converting Granite models to Dolomite format is currently unsupported."
-            )
-        if train_args.disable_flash_attn:
-            raise RuntimeError(
-                "ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported"
-            )
+    if train_args.use_dolomite and train_args.disable_flash_attn:
+        raise RuntimeError(
+            "ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported"
+        )

     if train_args.is_padding_free:
         print(
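With granite-to-dolomite conversion now supported, the per-model guard is gone and validation reduces to a single rule: the dolomite padding-free path requires flash attention. A self-contained sketch of that surviving rule (the dataclass below is a hypothetical stand-in for the repo's TrainingArgs):

from dataclasses import dataclass

@dataclass
class _ArgsStub:
    # hypothetical stand-in for instructlab.training's TrainingArgs
    use_dolomite: bool = False
    disable_flash_attn: bool = False

def check_dolomite_flash_attn(train_args: _ArgsStub) -> None:
    # Dolomite's padding-free transformer is built on flash attention,
    # so disabling flash attention while requesting dolomite is an error.
    if train_args.use_dolomite and train_args.disable_flash_attn:
        raise RuntimeError(
            "ERROR: Trying to use dolomite padding-free transformer "
            "without flash attention is not supported"
        )

# check_dolomite_flash_attn(_ArgsStub(use_dolomite=True, disable_flash_attn=True))  # would raise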
@@ -229,7 +221,7 @@ def pad_collate_fn(batch):

         input_ids.extend(item["input_ids"].tolist())
         labels.extend(item["labels"].tolist())
-        position_ids.extend(range(total_len, total_len + item_len))
+        position_ids.extend(range(item_len))

         total_len += item_len
         num_loss_counted_tokens += (item["labels"] != -100).sum().item()
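This one-line change is the behavioral core of the padding-free fix: when several sequences are packed into one batch, each sequence's position IDs now restart at zero instead of continuing from the running offset, so positional embeddings treat every packed sequence as starting fresh. A small sketch contrasting the two schemes (illustrative only, not the repo's collator):

seq_lens = [3, 4]  # lengths of two sequences packed into one batch

# Old scheme: positions run on across the whole packed batch.
old_ids, total_len = [], 0
for item_len in seq_lens:
    old_ids.extend(range(total_len, total_len + item_len))
    total_len += item_len
print(old_ids)  # [0, 1, 2, 3, 4, 5, 6]

# New scheme: positions restart for each packed sequence.
new_ids = []
for item_len in seq_lens:
    new_ids.extend(range(item_len))
print(new_ids)  # [0, 1, 2, 0, 1, 2, 3]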
@@ -802,10 +794,21 @@ def _get_state_dict_patched(model, unwrap=False):

     output_dir.mkdir(parents=True, exist_ok=True)
     if not model.module.config.architectures and convert_dolomite:
-        model.module.config.architectures = ["LlamaForCausalLM"]
-        warnings.warn(
-            f"Adding architectures to ckpt: {model.module.config.architectures}",
-        )
+        arch_added = False
+        if args.model_type == "llama":
+            model.module.config.architectures = ["LlamaForCausalLM"]
+            arch_added = True
+        elif args.model_type == "granite":
+            model.module.config.architectures = ["GraniteForCausalLM"]
+            arch_added = True
+        if arch_added:
+            warnings.warn(
+                f"Adding architectures to ckpt: {model.module.config.architectures}",
+            )
+        else:
+            warnings.warn(
+                "Converting from dolomite, but no architecture field added to config.json",
+            )
     model.module.config.to_json_file(output_config_file)
     tokenizer.save_pretrained(output_dir)
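The llama/granite branches above are equivalent to a small lookup table; a compact sketch of the same behavior (not the PR's code, and only these two families are handled):

import warnings
from types import SimpleNamespace

ARCHITECTURES_BY_MODEL_TYPE = {
    "llama": ["LlamaForCausalLM"],
    "granite": ["GraniteForCausalLM"],
}

def restore_architectures(config, model_type: str) -> None:
    # When converting a dolomite checkpoint back to HF format, the
    # "architectures" field may be missing and has to be re-added.
    architectures = ARCHITECTURES_BY_MODEL_TYPE.get(model_type)
    if architectures is not None:
        config.architectures = architectures
        warnings.warn(f"Adding architectures to ckpt: {architectures}")
    else:
        warnings.warn(
            "Converting from dolomite, but no architecture field added to config.json"
        )

cfg = SimpleNamespace(architectures=None)
restore_architectures(cfg, "granite")  # warns and sets ["GraniteForCausalLM"]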

@@ -834,7 +837,7 @@ def _get_state_dict_patched(model, unwrap=False):
     export_to_huggingface(
         pretrained_model_name_or_path=tmpdir.name,
         save_path=final_output_dir,
-        model_type="llama",
+        model_type=args.model_type,
     )
     tmpdir.cleanup()
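Downstream, the dolomite-to-HF conversion now propagates the detected family instead of assuming llama. For a Granite checkpoint the call would look like this (paths are made-up placeholders, and the import path is my assumption for the instructlab-dolomite package):

from instructlab.dolomite.hf_models import export_to_huggingface  # assumed import path

export_to_huggingface(
    pretrained_model_name_or_path="/tmp/dolomite_ckpt",   # hypothetical temp dir
    save_path="/data/output/hf_format/samples_1024",      # hypothetical output dir
    model_type="granite",  # previously hard-coded to "llama"
)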