Merge pull request #30 from rjojjr/update-docs
Add stop tokens
rjojjr authored Aug 11, 2024
2 parents 51ccd5a + 6c8d5e6 commit 0ece83a
Showing 5 changed files with 18 additions and 12 deletions.
src/main/base/llm_base_module.py: 1 addition & 1 deletion
@@ -12,7 +12,7 @@


 def _add_agent_tokens(tokenizer, model):
-    agent_tokens = ["Thought:", "Action:", "Action Input:", "Observation:", "Final Answer:"]
+    agent_tokens = ["\nThought:", "\nAction:", "\nAction Input:", "\nObservation:"]
     agent_tokens = set(agent_tokens) - set(tokenizer.vocab.keys())
     tokenizer.add_tokens(list(agent_tokens))
     if model is not None:
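For reference, registering stop strings as tokenizer tokens only works if the model's embedding table is resized to match the enlarged vocabulary, which the truncated remainder of _add_agent_tokens presumably handles. A minimal, self-contained sketch of the same idea; the model id is an assumption, not taken from this repository:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    def add_agent_tokens(tokenizer, model=None):
        # Newline-prefixed agent markers, mirroring the updated list above
        agent_tokens = ["\nThought:", "\nAction:", "\nAction Input:", "\nObservation:"]
        # Only register strings the tokenizer does not already know as tokens
        new_tokens = set(agent_tokens) - set(tokenizer.vocab.keys())
        tokenizer.add_tokens(list(new_tokens))
        if model is not None:
            # The embedding matrix must grow to cover the newly added token ids
            model.resize_token_embeddings(len(tokenizer))

    # Illustrative usage with an assumed checkpoint
    tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    lm = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    add_agent_tokens(tok, lm)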
src/main/llama/functions.py: 4 additions & 2 deletions
@@ -16,7 +16,9 @@ def merge(arguments: MergeArguments) -> None:
         torch_dtype=dtype
     )

-    tokenizer = AutoTokenizer.from_pretrained(arguments.base_model)
+    lora_dir = f"{arguments.output_dir}/checkpoints/{arguments.new_model}/adapter"
+
+    tokenizer = AutoTokenizer.from_pretrained(lora_dir)
     if arguments.padding_side is not None:
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.padding_side = arguments.padding_side
@@ -53,7 +55,7 @@ def push(arguments: PushArguments) -> None:


 def fine_tune(arguments: TuneArguments) -> None:
-    tokenizer = AutoTokenizer.from_pretrained(arguments.base_model, add_eos_token=True, add_bos_token=True)
+    tokenizer = AutoTokenizer.from_pretrained(arguments.base_model)
     if arguments.padding_side is not None:
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.padding_side = arguments.padding_side
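The merge() change loads the tokenizer from the LoRA adapter directory rather than the base model, so tokens added during fine-tuning (such as the agent markers above) are kept in the merged artifact. A rough sketch of that flow using PEFT is shown below; the base model id and directory layout are assumptions, and the repository's actual merge() handles more options than this:

    from peft import PeftModel
    from transformers import AutoTokenizer, LlamaForCausalLM

    base_model = "meta-llama/Llama-2-7b-hf"             # assumed base checkpoint
    lora_dir = "./output/checkpoints/my-model/adapter"  # mirrors {output_dir}/checkpoints/{new_model}/adapter

    # Load the tokenizer from the adapter directory so added tokens survive the merge
    tokenizer = AutoTokenizer.from_pretrained(lora_dir)

    model = LlamaForCausalLM.from_pretrained(base_model)
    # Grow the base embeddings to the fine-tuned vocabulary before applying the adapter
    model.resize_token_embeddings(len(tokenizer))

    merged = PeftModel.from_pretrained(model, lora_dir).merge_and_unload()
    merged.save_pretrained("./merged-model")
    tokenizer.save_pretrained("./merged-model")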
src/main/main.py: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 import os

 # TODO - Automate this
-version = '1.4.3'
+version = '1.4.4'

 # TODO - Change this once support for more LLMs is added
 title = f'Llama AI LLM LoRA Torch Text Fine-Tuner v{version}'
@@ -38,7 +38,7 @@ def main() -> None:
     print(f'Using fp32 CPU Offload: {str(args.fp32_cpu_offload)}')
     print()
     print(f"Serving {args.serve_model} on port {args.serve_port}")
-    factory = llm_executor_factory(LlmExecutorFactoryArguments(model=args.serve_model, use_4bit=args.use_4bit, use_8bit=args.use_8bit, is_fp16=args.use_fp_16, is_bf16=args.use_bf_16))
+    factory = llm_executor_factory(LlmExecutorFactoryArguments(model=args.serve_model, use_4bit=args.use_4bit, use_8bit=args.use_8bit, is_fp16=args.use_fp_16, is_bf16=args.use_bf_16, padding_side=args.padding_side))
     server = OpenAiLlmServer(factory())
     server.start_server(ServerArguments(port=args.serve_port, debug=args.debug))
     # TODO - cleaner exit
src/main/serve/llm_executor.py: 9 additions & 4 deletions
@@ -1,6 +1,6 @@
 from typing import Callable
 from arguments.arguments import LlmExecutorFactoryArguments
-from transformers import AutoTokenizer, LlamaForCausalLM
+from transformers import AutoTokenizer, LlamaForCausalLM, StopStringCriteria, StoppingCriteriaList
 from utils.torch_utils import get_bnb_config_and_dtype
 from exception.exceptions import TunerException
 import torch
@@ -26,16 +26,21 @@ def __init__(self, model, tokenizer, padding_side: str | None):
         if padding_side is not None:
             tokenizer.pad_token = tokenizer.eos_token
             model.generation_config.pad_token_id = tokenizer.pad_token_id
+            tokenizer.padding_side = padding_side

         self._model = model
         self._tokenizer = tokenizer

     # TODO - FIXME - multiple calls results in GPU memory overload(may be caused bnb?)
     # TODO - Stop sequences
-    def completion(self, input: str, max_tokens: int = 150, temperature: float = 1, attempt: int = 1):
+    def completion(self, input: str, max_tokens: int = 150, temperature: float = 1, attempt: int = 1, stops: list | None = None):
+        if stops is None:
+            stops = []
         try:
+            stopping_criteria = StoppingCriteriaList([StopStringCriteria(stop_strings=stops, tokenizer=self._tokenizer)])
             model_inputs = self._tokenizer([input], padding=True if self._padding_side is not None else False, return_tensors="pt").to("cuda")
             input_length = model_inputs.input_ids.shape[1]
-            generated_ids = self._model.generate(**model_inputs, max_new_tokens=max_tokens, do_sample=True, temperature=temperature)
+            generated_ids = self._model.generate(**model_inputs, max_new_tokens=max_tokens, do_sample=True, temperature=temperature, stopping_criteria=stopping_criteria)
             response = self._tokenizer.batch_decode(generated_ids[:, input_length:], skip_special_tokens=True)[0]
             # TODO - FIXME - big hack to stop OOM
             gc.collect()
@@ -67,6 +72,6 @@ def llm_executor_factory(arguments: LlmExecutorFactoryArguments) -> Callable[[],
         torch_dtype="auto"
         # TODO - investigate if this is effective
         # attn_implementation="flash_attention_2"
-    ), AutoTokenizer.from_pretrained(arguments.model, padding_side=arguments.padding_side))
+    ), AutoTokenizer.from_pretrained(arguments.model), padding_side=arguments.padding_side)
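StopStringCriteria and StoppingCriteriaList are stock transformers generation utilities (available in recent versions): generate() checks the criteria at each step and halts decoding once any stop string appears in the output. A minimal standalone sketch of the mechanism used by completion() above, with an assumed model id and illustrative stop strings:

    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              StoppingCriteriaList, StopStringCriteria)

    model_id = "meta-llama/Llama-2-7b-hf"  # assumed; any causal LM should behave the same
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")

    stops = ["\nObservation:", "user:"]  # illustrative stop strings
    # StopStringCriteria needs the tokenizer so it can match strings that span token boundaries
    criteria = StoppingCriteriaList([StopStringCriteria(stop_strings=stops, tokenizer=tokenizer)])

    inputs = tokenizer(["Thought: I should look that up.\nAction:"], return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=150, do_sample=True,
                                temperature=0.7, stopping_criteria=criteria)
    # Drop the prompt tokens before decoding, as completion() does
    print(tokenizer.batch_decode(output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0])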


src/main/serve/routers/open_ai_router.py: 2 additions & 3 deletions
@@ -14,7 +14,6 @@ def build_routes(app: Flask, llm: LlmExecutor) -> None:
     def _construct_chat_prompt(body: dict) -> str:
         prompt = ""
         for msg in body['messages']:
-            # TODO - Probably should replace `\n` with stop sequence(?)
             prompt = f"{prompt}{msg['role']}: {msg['content']}\n"
         return prompt

@@ -24,7 +23,7 @@ def chat_completions_endpoint():

         prompt = _construct_chat_prompt(body)

-        completion = llm.completion(prompt, int(body['max_tokens']), parse_temp(float(body['temperature'])))
+        completion = llm.completion(prompt, int(body['max_tokens']), parse_temp(float(body['temperature'])), stops=body['stop'])
         prompt_tokens = len(encoding.encode(prompt))
         completion_tokens = len(encoding.encode(completion))
         chat_response = {
@@ -53,7 +52,7 @@ def chat_completions_endpoint():
     @app.route("/v1/completions", methods=['POST'])
     def completions_endpoint():
         body = request.get_json(force=True)
-        completion = llm.completion(body['prompt'], int(body['max_tokens']), parse_temp(float(body['temperature'])))
+        completion = llm.completion(body['prompt'], int(body['max_tokens']), parse_temp(float(body['temperature'])), stops=body['stop'])
         prompt_tokens = len(encoding.encode(body['prompt']))
         completion_tokens = len(encoding.encode(completion))
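With the router change, the OpenAI-style endpoints forward the request's stop field straight to completion(); note that both handlers index body['stop'] directly, so requests are expected to supply it. An illustrative client call against the /v1/completions route; the host and port are assumptions (whatever args.serve_port in main.py was set to):

    import requests

    # Assumed local address for the running OpenAiLlmServer
    resp = requests.post(
        "http://localhost:8080/v1/completions",
        json={
            "prompt": "Thought: I need to check the weather.\nAction:",
            "max_tokens": 150,
            "temperature": 0.7,
            "stop": ["\nObservation:"],  # forwarded as completion(..., stops=body['stop'])
        },
    )
    print(resp.json())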
