Skip to content

Commit

Permalink
Merge pull request #29 from rjojjr/update-docs
Browse files Browse the repository at this point in the history
Fix agent model merge
  • Loading branch information
rjojjr authored Aug 11, 2024
2 parents ffd0f1a + 2b1cc68 commit 51ccd5a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
16 changes: 9 additions & 7 deletions src/main/base/llm_base_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from arguments.arguments import TuneArguments, MergeArguments, PushArguments
from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training, TaskType, PeftConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType, AutoPeftModelForCausalLM, PeftModel
from trl import SFTTrainer, SFTConfig, setup_chat_format
from transformers.trainer_utils import get_last_checkpoint
from utils.model_utils import get_all_layers, get_all_linear_layers
Expand All @@ -12,11 +12,11 @@


def _add_agent_tokens(tokenizer, model):
agent_tokens = ["Thought:", "Action:", "Input:", "Observation:", "Answer:", "Action\sInput:", "Final\sAnswer:"]

agent_tokens = ["Thought:", "Action:", "Action Input:", "Observation:", "Final Answer:"]
agent_tokens = set(agent_tokens) - set(tokenizer.vocab.keys())
tokenizer.add_tokens(list(agent_tokens))
model.resize_token_embeddings(len(tokenizer))
if model is not None:
model.resize_token_embeddings(len(tokenizer))


# TODO - Tune/extract an embeddings only model
Expand Down Expand Up @@ -135,10 +135,12 @@ def merge_base(arguments: MergeArguments, tokenizer, base_model, bnb_config) ->
lora_dir = f"{arguments.output_dir}/checkpoints/{arguments.new_model}/adapter"
model_dir = f'{arguments.output_dir}/{arguments.new_model}'
print(f"merging {arguments.base_model} with LoRA into {arguments.new_model}")
print('')

config = PeftConfig.from_pretrained(lora_dir)
model = PeftModel.from_pretrained(base_model, lora_dir, quantization_config=bnb_config, config=config)
if arguments.use_agent_tokens:
model = AutoPeftModelForCausalLM.from_pretrained(lora_dir)
else:
model = PeftModel.from_pretrained(base_model, lora_dir, quantization_config=bnb_config)

model = model.merge_and_unload(progressbar=True)
print('')

Expand Down
2 changes: 1 addition & 1 deletion src/main/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os

# TODO - Automate this
version = '1.4.2'
version = '1.4.3'

# TODO - Change this once support for more LLMs is added
title = f'Llama AI LLM LoRA Torch Text Fine-Tuner v{version}'
Expand Down

0 comments on commit 51ccd5a

Please sign in to comment.