Add Qwen2.5 vLLM generator (based on LlamaGenerator), fix batch 1 issue with generator's decode forward (#17422)
skhorasganiTT authored and hschoi4448 committed Feb 20, 2025
1 parent d5b5e1a commit e7f97bc
Showing 2 changed files with 74 additions and 29 deletions.
models/demos/llama3/tt/generator_vllm.py (95 changes: 69 additions & 26 deletions)
@@ -7,7 +7,6 @@
 import torch
 import PIL
 from llama_models.llama3.api.chat_format import create_vision_mask
-from llama_models.llama3.api.tokenizer import Tokenizer
 import ttnn
 
 from models.demos.llama3.tt.generator import LlamaGenerator
@@ -21,6 +20,41 @@
from vllm.model_executor.models.mllama import MLLAMA_IMAGE_TOKEN_ID, MLLAMA_IMAGE_TOKEN


+def initialize_vllm_text_transformer(
+    hf_config,
+    mesh_device,
+    max_batch_size,
+    max_seq_len,
+    n_layers=None,
+    dtype=ttnn.bfloat8_b,
+    optimizations=LlamaOptimizations.performance,
+):
+    # Load model args, weights
+    model_args = TtModelArgs(
+        mesh_device,
+        instruct=("Instruct" in hf_config._name_or_path),
+        max_batch_size=max_batch_size,
+        optimizations=optimizations,
+        max_seq_len=max_seq_len,
+    )
+    assert model_args.model_name.replace("-", "") in hf_config._name_or_path.replace(
+        "-", ""
+    ), f"The model specified in vLLM ({hf_config._name_or_path}) does not match the model name ({model_args.model_name}) with model weights ({model_args.DEFAULT_CKPT_DIR})."
+    if n_layers is not None:
+        model_args.n_layers = n_layers
+    state_dict = model_args.load_state_dict()
+
+    tt_model = TtTransformer(
+        args=model_args,
+        mesh_device=mesh_device,
+        dtype=dtype,
+        state_dict=state_dict,
+        weight_cache_path=model_args.weight_cache_path(dtype),
+        use_paged_kv_cache=True,
+    )
+    return tt_model, model_args


def input_processor_for_mllama(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]):
"""
Based on vllm.model_executor.models.mllama.py::input_processor_for_mllama().
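As an aside, a minimal usage sketch of the helper added above (not part of the commit): `initialize_vllm_text_transformer` builds the TT transformer and its model args from a Hugging Face config. The names `hf_config` and `mesh_device` are assumed to be supplied by the vLLM worker and the ttnn device setup, and `ttnn` / `LlamaOptimizations` are assumed available via the module's existing imports.

```python
# Usage sketch (illustrative, not from this commit).
def build_tt_text_model(hf_config, mesh_device):
    # Returns the TT transformer plus its TtModelArgs, with a paged KV cache enabled.
    tt_model, model_args = initialize_vllm_text_transformer(
        hf_config,
        mesh_device,
        max_batch_size=32,       # illustrative value
        max_seq_len=131072,      # same value the generators below pass
        n_layers=None,           # None keeps all layers from the checkpoint
        dtype=ttnn.bfloat8_b,
        optimizations=LlamaOptimizations.performance,
    )
    return tt_model, model_args
```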
@@ -140,33 +174,42 @@ def __init__(self, *args, **kwargs):

     @classmethod
     def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, n_layers=None):
-        instruct_mode = "Instruct" in hf_config._name_or_path
-        max_seq_len = 131072  # TODO: modify this for different models/devices
-        optimizations = LlamaOptimizations.performance  # TODO: maybe change to accuracy
-        dtype = ttnn.bfloat8_b
-
-        # Load model args, weights
-        model_args = TtModelArgs(
+        tt_model, model_args = initialize_vllm_text_transformer(
+            hf_config,
             mesh_device,
-            instruct=instruct_mode,
-            max_batch_size=max_batch_size,
-            optimizations=optimizations,
-            max_seq_len=max_seq_len,
+            max_batch_size,
+            max_seq_len=131072,
+            n_layers=n_layers,
+            dtype=ttnn.bfloat8_b,
+            optimizations=LlamaOptimizations.performance,
         )
-        assert (
-            model_args.model_name in hf_config._name_or_path
-        ), f"The model specified in vLLM ({hf_config._name_or_path}) does not match the model weights ({model_args.DEFAULT_CKPT_DIR})."
-        if n_layers is not None:
-            model_args.n_layers = n_layers
-        state_dict = model_args.load_state_dict()
-
-        tt_model = TtTransformer(
-            args=model_args,
-            mesh_device=mesh_device,
-            dtype=dtype,
-            state_dict=state_dict,
-            weight_cache_path=model_args.weight_cache_path(dtype),
-            use_paged_kv_cache=True,
+        return cls(tt_model, model_args, mesh_device)
+
+    @property
+    def cache_path(self):
+        return self.model_args.model_cache_path
+
+    def prefill_forward(self, *args, **kwargs):
+        return super().prefill_forward_text(*args, **kwargs)
+
+    def decode_forward(self, *args, **kwargs):
+        return super().decode_forward_text(*args, **kwargs)
+
+
+class TtQwen2ForCausalLM(LlamaGenerator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, n_layers=None):
+        tt_model, model_args = initialize_vllm_text_transformer(
+            hf_config,
+            mesh_device,
+            max_batch_size,
+            max_seq_len=131072,
+            n_layers=n_layers,
+            dtype=ttnn.bfloat8_b,
+            optimizations=LlamaOptimizations.performance,
+        )
+        return cls(tt_model, model_args, mesh_device)

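For context, a hypothetical sketch of how the new class could be exposed to vLLM (this wiring is not part of the diff, and the TT vLLM fork may register models differently). Qwen2.5 checkpoints report the `Qwen2ForCausalLM` architecture in their config, so mapping that name to the TT generator would let vLLM dispatch to it as an out-of-tree model.

```python
# Hypothetical registration sketch; the actual integration may differ.
from vllm import ModelRegistry

from models.demos.llama3.tt.generator_vllm import TtQwen2ForCausalLM

ModelRegistry.register_model("Qwen2ForCausalLM", TtQwen2ForCausalLM)
```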
models/demos/llama3/tt/llama_model.py (8 changes: 5 additions & 3 deletions)
@@ -174,8 +174,10 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None):
         assert current_pos.shape[0] == B, "Batch size mismatch"
         assert B == self.args.max_batch_size, "Batch size must be equal to max_batch_size"
 
+        # Necessary padding to be full tile sized when on device
+        tokens = torch.nn.functional.pad(tokens.view(-1), (0, 32 - len(tokens)), "constant", 0)
         tokens = ttnn.from_torch(
-            tokens.view(-1),
+            tokens,
             device=None,
             dtype=ttnn.uint32,
             mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
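Per the added comment, this change zero-pads the host-side token tensor to a full 32-wide tile before it is handed to `ttnn.from_torch`, which is what fixes decode for batch size 1 (a 1-element tensor is not tile sized). A torch-only illustration of the padding arithmetic:

```python
# Illustration of the padding line added above (assumes the 32-wide tile used in the diff).
import torch

tokens = torch.tensor([[101]])  # B = 1, one current token per user
padded = torch.nn.functional.pad(tokens.view(-1), (0, 32 - len(tokens)), "constant", 0)
print(padded.shape)             # torch.Size([32]) -> one full tile
```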
@@ -254,10 +256,10 @@ def process_output_decode(self, tt_out, B, S=1):
                 num_links=2,
                 cluster_axis=0,
                 mesh_device=self.mesh_device,
-                topology=ttnn.Topology.Linear,
+                topology=self.args.ccl_topology(),
             )
         else:
-            tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
+            tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=self.args.ccl_topology())
         tt_out = ttnn.untilize(tt_out, use_multicore=True)
         if self.args.num_devices > 1:
             tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float()
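Here the hard-coded `ttnn.Topology.Linear` is replaced with `self.args.ccl_topology()`, so the all-gather topology follows the model args instead of being fixed. A plausible sketch of such a helper, purely an assumption for illustration (the real logic lives on the model args class and may key off the actual mesh configuration):

```python
# Illustrative sketch of a ccl_topology() helper; assumed logic, not from the diff.
import ttnn

def ccl_topology(self):
    # Ring collectives need a closed loop of devices; otherwise fall back to Linear.
    if self.num_devices >= 8:  # assumed threshold for illustration
        return ttnn.Topology.Ring
    return ttnn.Topology.Linear
```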
