diff --git a/src/instructlab/training/chat_templates/ibm_generic_tmpl.py b/src/instructlab/training/chat_templates/ibm_generic_tmpl.py
new file mode 100644
index 00000000..87bfdb0a
--- /dev/null
+++ b/src/instructlab/training/chat_templates/ibm_generic_tmpl.py
@@ -0,0 +1,24 @@
+# First Party
+from instructlab.training.tokenizer_utils import SpecialTokens
+
+SPECIAL_TOKENS = SpecialTokens(
+ system="<|system|>",
+ user="<|user|>",
+ assistant="<|assistant|>",
+ eos="<|endoftext|>",
+ pad="<|pad|>",
+)
+
+CHAT_TEMPLATE = (
+ "{% for message in messages %}"
+ "{% if message['role'] == 'pretraining' %}"
+ "{{'<|endoftext|>' + message['content'] + '<|endoftext|>'}}"
+ "{% elif message['role'] == 'system' %}"
+ "{{'<|system|>'+ '\n' + message['content'] + '\n'}}"
+ "{% elif message['role'] == 'user' %}"
+ "{{'<|user|>' + '\n' + message['content'] + '\n'}}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{'<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n')}}"
+ "{% endif %}"
+ "{% endfor %}"
+)
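
To preview what this template produces, it can be rendered directly with jinja2 (the same engine `tokenizer.apply_chat_template` uses under the hood). A minimal sketch, assuming a source checkout where the new module is importable; the messages are made up:

```python
# Preview the exact string a processed sample becomes under the ibm-generic
# template. jinja2 is what HF tokenizers use to render chat templates.
from jinja2 import Template

from instructlab.training.chat_templates.ibm_generic_tmpl import CHAT_TEMPLATE

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]

print(Template(CHAT_TEMPLATE).render(messages=messages))
# <|system|>
# You are a helpful assistant.
# <|user|>
# What is 2 + 2?
# <|assistant|>
# 4<|endoftext|>
```
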
diff --git a/src/instructlab/training/chat_templates/mistral_tmpl.py b/src/instructlab/training/chat_templates/mistral_tmpl.py
new file mode 100644
index 00000000..965823f2
--- /dev/null
+++ b/src/instructlab/training/chat_templates/mistral_tmpl.py
@@ -0,0 +1,22 @@
+# First Party
+from instructlab.training.tokenizer_utils import SpecialTokens
+
+SPECIAL_TOKENS = SpecialTokens(
+ bos="",
+ eos="",
+ user="[INST]",
+ assistant="[/INST]",
+)
+
+CHAT_TEMPLATE = (
+ "{{ '' }}"
+ "{% for message in messages %}"
+ "{% if message['role'] == 'pretraining' %}"
+ "{{ message['content'] + '' }}"
+ "{% elif message['role'] == 'user' %}"
+ "{{ '[INST] ' + message['content'] + ' [/INST]' }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ message['content'] + ''}}"
+ "{% endif %}"
+ "{% endfor %}"
+)
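
Because `retrieve_chat_template` (added in utils.py below) loads any Python file defining `SPECIAL_TOKENS` and `CHAT_TEMPLATE`, users are not limited to the two bundled templates. A hypothetical custom module, saved anywhere and passed via `chat_tmpl_path`; every token string here is illustrative, not an officially supported format:

```python
# my_custom_tmpl.py -- hypothetical user-supplied template module.
# First Party
from instructlab.training.tokenizer_utils import SpecialTokens

# system and pad are omitted on purpose: system stays None (this format has
# no system role) and setup_tokenizer falls back to eos for padding.
SPECIAL_TOKENS = SpecialTokens(
    bos="<s>",
    eos="</s>",
    user="### User:",
    assistant="### Assistant:",
)

CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ '### User:\n' + message['content'] + '\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '### Assistant:\n' + message['content'] + '</s>' }}"
    "{% endif %}"
    "{% endfor %}"
)
```
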
diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py
index c8733a45..83c7a1f8 100644
--- a/src/instructlab/training/config.py
+++ b/src/instructlab/training/config.py
@@ -4,6 +4,7 @@
# Standard
from enum import Enum
+import os
# Third Party
from pydantic import BaseModel, ConfigDict, Field
@@ -42,6 +43,7 @@ class DataProcessArgs(BaseModel):
data_output_path: str
max_seq_len: int # defines the max sequence length of a sample
model_path: str # either a HF model name or path to HF model
+ chat_tmpl_path: str
# disable the protected namespace for the model_config field
model_config = ConfigDict(protected_namespaces=())
@@ -100,6 +102,11 @@ class TrainingArgs(BaseModel):
# Either the name of a HuggingFace model or a path to a model saved in HuggingFace format.
model_path: str
+ # Specify the chat template / special tokens for training (default is ibm-generic template/tokens)
+ chat_tmpl_path: str = os.path.join(
+ os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py"
+ )
+
# this field specifies the filepath to the training dataset before processing
data_path: str
ckpt_output_dir: str
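
With the new field, a data-processing configuration might look like the sketch below. Note that `DataProcessArgs.chat_tmpl_path` has no default, while `TrainingArgs` falls back to the bundled ibm-generic template; all paths are placeholders:

```python
from instructlab.training.config import DataProcessArgs

# Placeholder paths; chat_tmpl_path points at the new Mistral template.
args = DataProcessArgs(
    data_path="data/train.jsonl",
    data_output_path="data/processed",
    max_seq_len=4096,
    model_path="mistralai/Mistral-7B-v0.1",
    chat_tmpl_path="src/instructlab/training/chat_templates/mistral_tmpl.py",
)
```
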
diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py
index f8f43535..9301d185 100644
--- a/src/instructlab/training/data_process.py
+++ b/src/instructlab/training/data_process.py
@@ -2,6 +2,7 @@
from pathlib import Path
from typing import List
import logging
+import os
# Third Party
from datasets import load_dataset
@@ -10,12 +11,8 @@
# First Party
from instructlab.training.config import DataProcessArgs
-from instructlab.training.tokenizer_utils import (
- SPECIAL_TOKENS,
- get_sp_token,
- setup_tokenizer,
-)
-from instructlab.training.utils import log_rank_0, setup_logger
+from instructlab.training.tokenizer_utils import get_sp_token, setup_tokenizer
+from instructlab.training.utils import log_rank_0, retrieve_chat_template, setup_logger
def check_valid_sample(
@@ -37,18 +34,6 @@ def check_valid_sample(
if not any(token in whole_sentence_tk for token in special_tokens):
return True
- # first token should be system_token
- if whole_sentence_tk[0] != system_tk:
- print("\033[91mfirst token is not a system_token\033[0m")
- log_rank_0(tokenizer.decode(whole_sentence_tk), to_print=True)
- return False
-
- # check there's only one system_token
- if (np.array(whole_sentence_tk) == system_tk).sum() != 1:
- print("\033[91mthere are more than one system_token\033[0m")
- log_rank_0(tokenizer.decode(whole_sentence_tk), to_print=True)
- return False
-
whole_sentence_tk = np.array(whole_sentence_tk)
user_token_index = (whole_sentence_tk == user_tk).nonzero()[0]
assistant_token_index = (whole_sentence_tk == assistant_tk).nonzero()[0]
@@ -121,7 +106,11 @@ def unmask_only_assistant_responses(
whole_sentence = chosen_token["input_ids"][:sentence_legth].clone()
# pre-training mode
- if system_tk not in whole_sentence:
+ if not (
+ system_tk in whole_sentence
+ or user_token in whole_sentence
+ or assist_token in whole_sentence
+ ):
return labels
labels[:sentence_legth] = -100
@@ -204,11 +193,15 @@ def remove_pretrain_system_messages(example: dict):
def main(args: DataProcessArgs):
- tokenizer = setup_tokenizer(args.model_path)
+ CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(args.chat_tmpl_path)
+ tokenizer = setup_tokenizer(args.model_path, SPECIAL_TOKENS, CHAT_TEMPLATE)
eos_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.eos)
pad_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.pad)
- system_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.system)
+ if SPECIAL_TOKENS.system:
+ system_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.system)
+ else:
+ system_tk = None
user_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.user)
assistant_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.assistant)
log_rank_0(
@@ -309,6 +302,14 @@ def main(args: DataProcessArgs):
parser.add_argument(
"--model_name_or_path", type=str, required=True, help="Model name or path"
)
+ parser.add_argument(
+ "--chat-tmpl-path",
+ type=str,
+ default=os.path.join(
+ os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py"
+ ),
+ help="Path to desired chat template and special tokens, defaults to IBM generic.",
+ )
args = parser.parse_args()
setup_logger(args.logging_level)
data_process_args = DataProcessArgs(
@@ -316,6 +317,7 @@ def main(args: DataProcessArgs):
data_path=args.data_path,
max_seq_len=args.max_seq_len,
model_path=args.model_name_or_path,
+ chat_tmpl_path=args.chat_tmpl_path,
)
main(data_process_args)
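
The relaxed pretraining check above no longer requires a system token: a sample now counts as pretraining only when none of the role tokens appear. A standalone sketch of that condition, with made-up token ids:

```python
# Illustration of the new pretraining detection in
# unmask_only_assistant_responses; token ids are invented for the example.
def is_pretraining(input_ids, system_tk, user_tk, assistant_tk):
    # system_tk may be None for templates without a system role (e.g. Mistral)
    role_tokens = [t for t in (system_tk, user_tk, assistant_tk) if t is not None]
    return not any(t in input_ids for t in role_tokens)

print(is_pretraining([5, 9, 11], None, 32002, 32003))            # True: no role tokens
print(is_pretraining([32002, 5, 32003, 9], None, 32002, 32003))  # False: chat sample
```
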
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index eeb0c077..1eb21a68 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -37,6 +37,7 @@
patch_target_module,
prepare_peft_model,
prepare_universal_checkpoint_from_latest,
+ retrieve_chat_template,
save_hf_format_ds,
save_model_ds_native,
set_random_seed,
@@ -438,7 +439,8 @@ def main(args):
print(f"\033[38;5;120m{yaml.dump(vars(args), sort_keys=False)}\033[0m")
setup_logger(args.log_level)
- tokenizer = setup_tokenizer(args.model_name_or_path)
+ CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(args.chat_tmpl_path)
+ tokenizer = setup_tokenizer(args.model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE)
# device = torch.device("cuda", args.local_rank)
#### distributed init #####
@@ -522,6 +524,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
model_path=train_args.model_path,
data_path=train_args.data_path,
max_seq_len=train_args.max_seq_len,
+ chat_tmpl_path=train_args.chat_tmpl_path,
)
)
@@ -546,6 +549,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
f"--log_level=INFO",
f"--max_batch_len={train_args.max_batch_len}",
f"--seed={train_args.random_seed}",
+ f"--chat-tmpl-path={train_args.chat_tmpl_path}",
]
if train_args.mock_data:
@@ -644,6 +648,14 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
        help="Offload optimizer to CPU when using DeepSpeed. This configures it to use ZeRO stage 2.",
    )
    parser.add_argument("--NEFTune_alpha", type=float, default=None)
+    parser.add_argument(
+        "--chat-tmpl-path",
+        type=str,
+        default=os.path.join(
+            os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py"
+        ),
+        help="Path to desired chat template and special tokens, defaults to IBM generic.",
+    )
args = parser.parse_args()
set_random_seed(args.seed)
main(args)
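
End to end, the new knob flows from `TrainingArgs` into both the preprocessing step and the torchrun command line. A hedged sketch of a high-level invocation; only `chat_tmpl_path` and the fields visible in this diff are certain, the remaining names and values are illustrative, and the full required set lives in `config.py`:

```python
from instructlab.training.config import TorchrunArgs, TrainingArgs
from instructlab.training.main_ds import run_training

# Field names beyond those shown in this diff are assumptions; check config.py.
train_args = TrainingArgs(
    model_path="mistralai/Mistral-7B-v0.1",
    data_path="data/train.jsonl",
    ckpt_output_dir="checkpoints",
    max_seq_len=4096,
    max_batch_len=60000,
    random_seed=42,
    # omit chat_tmpl_path to fall back to the bundled ibm-generic template
    chat_tmpl_path="src/instructlab/training/chat_templates/mistral_tmpl.py",
)
torch_args = TorchrunArgs(nnodes=1, nproc_per_node=8)  # assumed fields

run_training(torch_args=torch_args, train_args=train_args)
```
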
diff --git a/src/instructlab/training/tokenizer_utils.py b/src/instructlab/training/tokenizer_utils.py
index 7eff0e69..5c789441 100644
--- a/src/instructlab/training/tokenizer_utils.py
+++ b/src/instructlab/training/tokenizer_utils.py
@@ -10,46 +10,34 @@
@dataclass
class SpecialTokens:
- system: str = field(default="<|system|>")
+ system: str = field(default=None)
user: str = field(default="<|user|>")
assistant: str = field(default="<|assistant|>")
eos: str = field(default="<|endoftext|>")
- pad: str = field(default="<|pad|>")
+ pad: str = field(default=None)
+ bos: str = field(default="<|begginingoftext|>")
-SPECIAL_TOKENS = SpecialTokens()
-
-CHAT_TEMPLATE = (
- "{% for message in messages %}"
- "{% if message['role'] == 'pretraining' %}"
- "{{'<|endoftext|>' + message['content'] + '<|endoftext|>'}}"
- "{% elif message['role'] == 'system' %}"
- "{{'<|system|>'+ '\n' + message['content'] + '\n'}}"
- "{% elif message['role'] == 'user' %}"
- "{{'<|user|>' + '\n' + message['content'] + '\n'}}"
- "{% elif message['role'] == 'assistant' %}"
- "{{'<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n')}}"
- "{% endif %}"
- "{% endfor %}"
-)
-
-
-def setup_tokenizer(
- model_name_or_path, SPECIAL_TOKENS=SPECIAL_TOKENS, CHAT_TEMPLATE=CHAT_TEMPLATE
-):
+def setup_tokenizer(model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE):
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, fast_tokenizer=True)
- tokenizer.add_special_tokens(
- {"eos_token": SPECIAL_TOKENS.eos, "pad_token": SPECIAL_TOKENS.pad}
- )
+
+ if not SPECIAL_TOKENS.pad:
+ SPECIAL_TOKENS.pad = SPECIAL_TOKENS.eos
tokenizer.add_special_tokens(
{
- "additional_special_tokens": [
- SPECIAL_TOKENS.system,
- SPECIAL_TOKENS.user,
- SPECIAL_TOKENS.assistant,
- ]
+ "bos_token": SPECIAL_TOKENS.bos,
+ "eos_token": SPECIAL_TOKENS.eos,
+ "pad_token": SPECIAL_TOKENS.pad,
}
)
+
+ if SPECIAL_TOKENS.system:
+ add_token_list = [SPECIAL_TOKENS.system]
+ else:
+ add_token_list = []
+ add_token_list.extend([SPECIAL_TOKENS.user, SPECIAL_TOKENS.assistant])
+
+ tokenizer.add_special_tokens({"additional_special_tokens": add_token_list})
if getattr(tokenizer, "add_bos_token", False) or getattr(
tokenizer, "add_eos_token", False
):
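
The pad fallback means templates that define no pad token (like the Mistral one) reuse eos for padding. A short sketch, assuming a source checkout; the model name is a placeholder:

```python
from instructlab.training.tokenizer_utils import setup_tokenizer
from instructlab.training.utils import retrieve_chat_template

CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(
    "src/instructlab/training/chat_templates/mistral_tmpl.py"
)
tokenizer = setup_tokenizer("mistralai/Mistral-7B-v0.1", SPECIAL_TOKENS, CHAT_TEMPLATE)

# With no pad defined, setup_tokenizer reuses eos as pad.
assert tokenizer.pad_token == tokenizer.eos_token == "</s>"
```
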
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 6feaa548..5eccb5dc 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -24,6 +24,19 @@
import torch.nn.functional as F
+def retrieve_chat_template(chat_tmpl_path):
+ try:
+ spec = importlib.util.spec_from_file_location("spcl_chat_tmpl", chat_tmpl_path)
+ module = importlib.util.module_from_spec(spec)
+ sys.modules["spcl_chat_tmpl"] = module
+ spec.loader.exec_module(module)
+ SPECIAL_TOKENS = module.SPECIAL_TOKENS
+ CHAT_TEMPLATE = module.CHAT_TEMPLATE
+    except Exception:
+        sys.exit(f"Invalid chat template path: {chat_tmpl_path}")
+ return CHAT_TEMPLATE, SPECIAL_TOKENS
+
+
def add_noisy_embeddings(model, noise_alpha=None):
if not noise_alpha:
return model
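
For reference, the `retrieve_chat_template` loader added above can be exercised on its own; the relative path assumes a source checkout:

```python
from instructlab.training.utils import retrieve_chat_template

CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(
    "src/instructlab/training/chat_templates/ibm_generic_tmpl.py"
)
print(SPECIAL_TOKENS.system)  # <|system|>
print(SPECIAL_TOKENS.pad)     # <|pad|>
# An invalid path exits with "Invalid chat template path: ..."
```
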