Commit 45e4e61: Rename the module

epicfilemcnulty committed May 12, 2024
1 parent b8e95c6

Showing 7 changed files with 111 additions and 23 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -3,12 +3,12 @@ requires = ["flit_core >=3.2,<4"]
build-backend = "flit_core.buildapi"

[project]
name = "lilush_llm_proxy"
name = "lilush_llm_backend"
authors = [{name = "Vladimir Zorin", email = "[email protected]"}]
license = {file = "LICENSE"}
classifiers = ["License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)"]
dynamic = ["version", "description"]
dependencies = [ 'torch', 'numpy', 'bltzr', 'peft', 'bottle' ]

[project.urls]
Home = "https://github.com/epicfilemcnulty/lilush-llm-proxy"
Home = "https://github.com/epicfilemcnulty/lilush-llm-backend"
2 changes: 1 addition & 1 deletion serve.py
@@ -1,5 +1,5 @@
import argparse
from lilush_llm_proxy import Serve
from lilush_llm_backend import Serve

parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default="127.0.0.1", required=False)
src/lilush_llm_proxy/__init__.py → src/lilush_llm_backend/__init__.py
@@ -1,6 +1,6 @@
""" Lilush LLM Proxy """

__version__ = "0.1.0"
__version__ = "0.1.3"

from .loader import *
from .generation import *
src/lilush_llm_proxy/generation.py → src/lilush_llm_backend/generation.py
@@ -66,15 +66,17 @@ def MambaQuery(query, sampler, model, tokenizer):

    input_ids = torch.LongTensor(tokens).unsqueeze(0).cuda()
    prompt_tokens = len(input_ids[0])

    output_ids = model.generate(
        input_ids=input_ids,
        max_length=prompt_tokens + sampler['max_new_tokens'],
        temperature=sampler['temperature'],
        top_k=sampler['top_k'],
        top_p=sampler['top_p'],
        min_p=sampler['min_p'],
        repetition_penalty = sampler['repetition_penalty'],
        eos_token_ids=sc,
        cg=True,
        eos_token_ids=sc,
        do_sample=True,
        repetition_penalty = sampler['repetition_penalty'],
    )
    gen_text = tokenizer.decode(output_ids[0], hide_special_tokens=sampler['hide_special_tokens'])
    old_text = ""
src/lilush_llm_proxy/loader.py → src/lilush_llm_backend/loader.py
@@ -1,5 +1,5 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, MambaForCausalLM
from bltzr import Tokenizer
from peft import PeftModel
from .mixin import GenerationMixin
@@ -54,10 +54,57 @@ def LoadTfModel(model_dir, context_length=None, lora_dir=None, trust_remote_code

return { "model": model, "tokenizer": tokenizer, "type": "tf" }

class CustomAutoModelForCausalLM(AutoModelForCausalLM, GenerationMixin):
    pass
class CustomModelForCausalLM(MambaForCausalLM, GenerationMixin):
    def _validate_model_kwargs(self, model_kwargs):
        # Skip validation for unsupported arguments
        supported_kwargs = [
            "max_length",
            "min_length",
            "do_sample",
            "early_stopping",
            "num_beams",
            "temperature",
            "top_k",
            "top_p",
            "repetition_penalty",
            "bad_words_ids",
            "bos_token_id",
            "pad_token_id",
            "eos_token_id",
            "length_penalty",
            "no_repeat_ngram_size",
            "encoder_no_repeat_ngram_size",
            "num_return_sequences",
            "max_time",
            "max_new_tokens",
            "decoder_start_token_id",
            "use_cache",
            "num_beam_groups",
            "diversity_penalty",
            "prefix_allowed_tokens_fn",
            "logits_processor",
            "renormalize_logits",
            "stopping_criteria",
            "constraints",
            "output_attentions",
            "output_hidden_states",
            "output_scores",
            "return_dict_in_generate",
            "forced_bos_token_id",
            "forced_eos_token_id",
            "remove_invalid_values",
            "synced_gpus",
            "exponential_decay_length_penalty",
            "suppress_tokens",
            "begin_suppress_tokens",
            "forced_decoder_ids",
        ]
        # Iterate over a copy of the keys (via `list`) so we can pop from model_kwargs while looping
        for key in list(model_kwargs):
            if key not in supported_kwargs:
                model_kwargs.pop(key)
        super()._validate_model_kwargs(model_kwargs)

def LoadMambaModel(model_dir):
    tokenizer = Tokenizer()
    model = CustomAutoModelForCausalLM.from_pretrained(model_dir)
    model = CustomModelForCausalLM.from_pretrained(model_dir).to('cuda')
    return { "model": model, "tokenizer": tokenizer, "type": "mamba" }
63 changes: 51 additions & 12 deletions src/lilush_llm_proxy/mixin.py → src/lilush_llm_backend/mixin.py
@@ -34,6 +34,31 @@ def reset(self, max_seqlen, max_batch_size):
        self.lengths_per_sample.zero_()


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for non-top-k values to -inf. Done in-place."""
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non-top-p values to -inf. Done in-place."""
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, float("-inf"))

def modify_logits_for_min_p_filtering(logits, min_p):
    """Set the logits below the min_p threshold to -inf. Done in-place."""
    if min_p <= 0.0 or min_p >= 1.0:
@@ -54,32 +79,47 @@ def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_p
    logits.scatter_(1, prev_output_tokens, score)
    return logits

def sample(logits, greedy=False, min_p=0.0, temperature=1.0):
def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if greedy:
    if top_k == 1: # Greedy decoding
        return logits.argmax(dim=-1)

    logits_top = logits.clone()
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."

    if top_k > 0:
        top_k = min(top_k, logits.size(-1)) # Safety check
        logits_top, indices = torch.topk(logits, top_k, dim=-1)
        if temperature != 1.0:
            logits_top /= temperature
        modify_logits_for_top_p_filtering(logits_top, top_p)
        return indices[
            torch.arange(indices.shape[0], device=indices.device),
            torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
        ]
    if min_p > 0.0:
        logits_top = logits.clone()
        max_prob = logits_top[..., 0].item()
        min_prob = max_prob * min_p
        modify_logits_for_min_p_filtering(logits_top, min_prob)
        if temperature != 1.0:
            logits_top /= temperature
        return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)

    if temperature != 1.0:
        logits_top /= temperature

    logits_top = logits / temperature if temperature != 1.0 else logits.clone()
    modify_logits_for_top_p_filtering(logits_top, top_p)
    return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)


@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    greedy=True,
    top_k=1,
    top_p=0.0,
    min_p=0.0,
    temperature=1.0,
    repetition_penalty=1.0,
@@ -148,7 +188,7 @@ def get_logits(input_ids, inference_params):
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        token = sample(logits, greedy=greedy, min_p=min_p, temperature=temperature)
        token = sample(logits, top_k=top_k, min_p=min_p, temperature=temperature)
        # return rearrange(token, "b -> b 1")
        return token.unsqueeze(1)

@@ -188,7 +228,7 @@ def should_stop(current_token, inference_params):
    end.record()
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if greedy else SampleDecoderOnlyOutput
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))


@@ -209,13 +249,12 @@ def generate(
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
            input_ids, self, max_length, top_k=top_k, eos_token_ids=eos_token_ids, cg=cg, min_p = min_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
File renamed without changes.
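
As a quick reference for the reworked sampling entry point in mixin.py, here is a small sketch of how sample() now chooses a decoding strategy. The toy logits are made up, and the import assumes the helpers above are reachable as module-level functions of lilush_llm_backend.mixin.

import torch
from lilush_llm_backend.mixin import sample  # assumed import path for the helpers above

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

# top_k == 1 short-circuits to a greedy argmax over the logits
greedy_token = sample(logits, top_k=1)

# top_k > 1 samples from the top-k distribution, optionally trimmed by top-p and scaled by temperature
sampled_token = sample(logits, top_k=3, top_p=0.9, temperature=0.8)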
