another rebase and fixes
alexm-redhat committed Jan 24, 2025
1 parent 167c0f2 commit d50e6c7
Showing 3 changed files with 10 additions and 3 deletions.
vllm/v1/attention/backends/pallas.py (2 changes: 1 addition & 1 deletion)

@@ -179,7 +179,7 @@ def forward(
         output = torch.ones_like(query)
         return output
 
-        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
+        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         batch_size, seq_len, hidden_size = query.shape
         query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
         key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
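For context, the rename suggests the scale attributes were split into a device-tensor variant and a plain-float mirror. A minimal sketch of that pattern, with the relationship between the two attribute families assumed from the diff rather than verified against the vLLM source:

import torch


class AttentionLayerSketch(torch.nn.Module):
    """Sketch of the pattern implied by the rename above.

    Assumption (not confirmed by the diff): `_k_scale`/`_v_scale` are
    device tensors consumed by kernels, while `_k_scale_float` and
    `_v_scale_float` mirror them as plain Python floats so host-side
    checks like the assert in pallas.py stay cheap.
    """

    def __init__(self, k_scale: float = 1.0, v_scale: float = 1.0) -> None:
        super().__init__()
        self._k_scale = torch.tensor(k_scale)   # used on-device
        self._v_scale = torch.tensor(v_scale)
        self._k_scale_float = k_scale           # host-side mirror
        self._v_scale_float = v_scale

Under that assumption, `assert layer._k_scale_float == 1.0` is a pure-Python comparison, whereas asserting on a tensor would allocate a boolean tensor and force a device-to-host sync when `assert` evaluates it.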
vllm/v1/worker/gpu_worker.py (6 changes: 4 additions & 2 deletions)

@@ -171,6 +171,8 @@ def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.model_runner.initialize_kv_cache(kv_cache_config)
 
     def compile_or_warm_up_model(self) -> None:
+        assert self.model_runner is not None
+
         # warm up sizes that are not in cudagraph capture sizes,
         # but users still want to compile for better performance,
         # e.g. for the max-num-batched token size in chunked prefill.
@@ -182,8 +184,8 @@ def compile_or_warm_up_model(self) -> None:
         ]
         for size in sorted(warmup_sizes, reverse=True):
             logger.info("Compile and warming up model for size %d", size)
-            self.model_runner._dummy_run(size)
+            self.model_runner.dummy_run(None, size)
 
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
         # Reset the seed to ensure that the random state is not affected by
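The warmup loop also moves from the private `_dummy_run(size)` to a public `dummy_run(None, size)`. A hedged sketch of what such an entry point could look like; the parameter names and the meaning of the leading `None` (read here as "no KV caches yet") are assumptions taken from the call site, not the real vLLM signature:

from typing import List, Optional

import torch


class ModelRunnerSketch:
    """Illustrative stand-in for the v1 model runner warmup entry point."""

    def __init__(self, model: torch.nn.Module) -> None:
        self.model = model

    @torch.inference_mode()
    def dummy_run(self, kv_caches: Optional[List[torch.Tensor]],
                  num_tokens: int) -> None:
        # During warmup no real KV caches exist, so callers pass None;
        # the forward pass exists only to trigger compilation for this
        # batch size, and its output is discarded.
        dummy_input = torch.zeros(num_tokens, dtype=torch.long)
        self.model(dummy_input)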
vllm/v1/worker/tpu_model_runner.py (5 changes: 5 additions & 0 deletions)

@@ -246,6 +246,7 @@ def _prepare_prefill_inputs(
             num_decode_tokens=0,
             slot_mapping=slot_mapping.to(self.device),
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
             block_tables=None,
             context_lens=None,
             effective_query_lens=None,
@@ -328,6 +329,7 @@ def _prepare_decode_inputs(self) -> DecodeInputData:
             num_decode_tokens=padded_batch_size,
             slot_mapping=slot_mapping.to(self.device),
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
             block_tables=block_table.to(self.device),
             context_lens=context_lens.to(self.device),
             effective_query_lens=None,
@@ -517,6 +519,7 @@ def dummy_run(
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
             block_tables=None,
             context_lens=None,
             effective_query_lens=None,
@@ -540,6 +543,7 @@ def dummy_run(
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
             block_tables=block_tables,
             context_lens=context_lens,
             effective_query_lens=effective_query_lens,
@@ -568,6 +572,7 @@ def dummy_run(
             num_decode_tokens=num_tokens * seq_len,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
             block_tables=block_tables,
             context_lens=context_lens,
         )
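All five hunks thread the same new keyword through attention-metadata constructors. One plausible reading, sketched below, is that the metadata type gained a required `enable_kv_scales_calculation` field, which would force every call site to be updated in a single commit; only the field names visible in the diff are taken from the source, and the type itself is an assumption:

from dataclasses import dataclass
from typing import Any, Optional

import torch


@dataclass
class PallasMetadataSketch:
    """Toy stand-in for the TPU attention metadata touched above.

    Declaring enable_kv_scales_calculation without a default makes it
    a required constructor argument, which would explain the five
    identical additions in this commit.
    """
    num_decode_tokens: int
    slot_mapping: torch.Tensor
    multi_modal_placeholder_index_maps: Optional[Any]
    enable_kv_scales_calculation: bool
    block_tables: Optional[torch.Tensor]
    context_lens: Optional[torch.Tensor]
    effective_query_lens: Optional[torch.Tensor] = None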
