From 2c5f15443ddc4e3159c781dc174af48c5b9ea119 Mon Sep 17 00:00:00 2001
From: Salar Hosseini <159165450+skhorasganiTT@users.noreply.github.com>
Date: Fri, 31 Jan 2025 14:36:01 -0500
Subject: [PATCH] [Llama3.2-11b-vision] Add max_cross_attn_tokens property to
 vLLM generator class (#17401)

---
 models/demos/llama3/tt/generator_vllm.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/models/demos/llama3/tt/generator_vllm.py b/models/demos/llama3/tt/generator_vllm.py
index cff4b51b4402..846e0cef34f0 100644
--- a/models/demos/llama3/tt/generator_vllm.py
+++ b/models/demos/llama3/tt/generator_vllm.py
@@ -130,6 +130,10 @@ def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size):
     def cache_path(self):
         return self.model_args.model_cache_path
 
+    @property
+    def max_cross_attn_tokens(self):
+        return self.model_args.vision_max_num_chunks * nearest_32(self.model_args.vision_chunk_ntok)
+
     def prefill_forward(
         self,
         tokens: torch.Tensor,
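
For context, the added property derives an upper bound on cross-attention tokens
from the vision model arguments: the maximum number of image chunks multiplied by
the per-chunk token count padded up to a tile boundary. Below is a minimal sketch
of the arithmetic, assuming nearest_32 rounds up to the next multiple of 32
(inferred from the helper's name, not from the tt-metal source) and using
illustrative values for vision_max_num_chunks and vision_chunk_ntok:

    import math

    def nearest_32(x: int) -> int:
        # Assumed behavior: round x up to the next multiple of 32
        # (32 being the tile dimension used throughout tt-metal).
        return math.ceil(x / 32) * 32

    # Illustrative values, not taken from the actual model config:
    # e.g. 560x560 chunks with 14x14 patches -> 40*40 + 1 = 1601 tokens.
    vision_max_num_chunks = 4
    vision_chunk_ntok = 1601

    max_cross_attn_tokens = vision_max_num_chunks * nearest_32(vision_chunk_ntok)
    print(max_cross_attn_tokens)  # 4 * 1632 = 6528

Padding each chunk's token count to a multiple of 32 before multiplying matches
how the rest of the generator sizes device buffers in tile units, so the property
gives vLLM a safe, tile-aligned maximum rather than the raw token count.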