
[WIP] [V1] TPU support #11936

Open · wants to merge 1 commit into base: main

Conversation

alexm-redhat (Collaborator) commented Jan 10, 2025

This PR is a rebase and modification of @robertgshaw2-redhat's original PR for TPU support in vLLM V1 from 1.5 months ago (#10241).

Currently, the TPU attention kernel does not support mixing prefills and decodes in the same scheduler iteration. As a result, this PR separates the requests into (1) prefills and (2) decodes and executes each group separately. The Google team is working on a new TPU attention kernel that will allow mixing prefills and decodes; the moment it is ready, we will be able to remove the separation logic and unify the requests (which will also provide better performance).
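
To make the split concrete, here is a minimal sketch of the idea (the helper names _run_prefills and _run_decodes are illustrative, not the PR's exact API):

def execute_model(self, scheduler_output):
    # Split the scheduled requests into prefills (prompt not yet in the KV
    # cache) and decodes (requests that are already generating).
    prefill_reqs, decode_reqs = [], []
    for req_id in scheduler_output.num_scheduled_tokens:
        if self.requests[req_id].num_computed_tokens == 0:
            prefill_reqs.append(req_id)
        else:
            decode_reqs.append(req_id)

    sampled_token_ids = {}
    # Run each group separately, since the current Pallas attention kernel
    # cannot mix prefills and decodes in a single call.
    if prefill_reqs:
        sampled_token_ids.update(self._run_prefills(prefill_reqs, scheduler_output))
    if decode_reqs:
        sampled_token_ids.update(self._run_decodes(decode_reqs, scheduler_output))
    return sampled_token_ids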

Notes:

  1. @mgoin verified correctness with GSM8K on a TPU instance
  2. No TP > 1 support yet
  3. Only the greedy sampler for now (see the argmax sketch after this list)
  4. The V1 code had no support for multiple architectures (this PR adds support for CUDA and TPU), which results in some code duplication; we avoid it as much as possible by introducing base classes for the worker and model runner.
  5. Not performance optimized yet
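
Regarding note 3: with only greedy sampling, token selection reduces to an argmax over the last-position logits. A minimal sketch (not the PR's actual sampler code):

import torch

def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
    # logits: [batch_size, vocab_size] scores for the last position of each
    # request. Greedy decoding simply picks the highest-scoring token per row.
    return torch.argmax(logits, dim=-1)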

Follow-up tasks (I may have missed something):

  1. Add all sampler options
  2. Add prefix caching (currently supported in V0 TPU)
  3. Add prefill chunking
  4. Integrate with Google's new attention kernel to support mixing prefills and decodes
  5. Optimize


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀


mergify bot commented Jan 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @alexm-neuralmagic.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Comment on lines +382 to +261
return PrefillInputData(
request_ids=prefill_request_ids,
prompt_lens=prefill_prompt_lens,
token_ids=prefill_token_ids,
position_ids=prefill_position_ids,
attn_metadata=prefill_attn_metadata,
)
Contributor:

Remove the PrefillInputData data structure and make it consistent with gpu_model_runner?

Collaborator (Author):

This will be removed the moment Google provides the new attention kernel that supports chunked prefill.

vllm/v1/worker/tpu_model_runner.py (resolved review threads)
effective_query_lens=None,
))

def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
Contributor:

This is almost identical to the current gpu_model_runner implementation; consider reusing instead of duplicating?

Collaborator (Author):

Added model_runner_base.py to hold common funcs
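
For illustration, the shared-base-class shape looks roughly like this (class and method names are illustrative, not the exact contents of model_runner_base.py):

from abc import ABC, abstractmethod


class ModelRunnerBase(ABC):
    # Holds logic shared by the CUDA and TPU V1 model runners, e.g. request
    # bookkeeping and common input preparation.

    def _update_states(self, scheduler_output) -> None:
        # Common request-state updates shared across backends.
        ...

    @abstractmethod
    def execute_model(self, scheduler_output):
        # Backend-specific forward pass (CUDA graphs vs. XLA-compiled graphs).
        raise NotImplementedError


class TPUModelRunner(ModelRunnerBase):
    def execute_model(self, scheduler_output):
        # TPU-specific path: run prefills and decodes in separate batches.
        ...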

mgoin (Member) commented Jan 13, 2025

Successfully ran an eval on GSM8k

VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=Qwen/Qwen2.5-1.5B-Instruct,max_model_len=2048,max_num_seqs=512 --tasks gsm8k --num_fewshot 5 --batch_size auto
...
vllm (pretrained=Qwen/Qwen2.5-1.5B-Instruct,max_model_len=2048,max_num_seqs=512), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5989|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.5428|±  |0.0137|

parallel_config.worker_cls = \
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
if envs.VLLM_USE_V1:
parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TRUWorker"
Member:

Suggested change
parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TRUWorker"
parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"

@@ -0,0 +1,148 @@
"""A GPU worker class."""
Member:

Suggested change
"""A GPU worker class."""
"""A TPU worker class."""

Collaborator (Author):

nice catch! :)


# TPU only supports DYNAMO_ONCE compilation level
if (compilation_config.level == CompilationLevel.NO_COMPILATION
or compilation_config.level == CompilationLevel.PIECEWISE):


The assert below makes sure it fails unless compilation_config.level < CompilationLevel.PIECEWISE. So do you still need to check if compilation_config.level == CompilationLevel.PIECEWISE here?

Collaborator (Author):

We can remove it

if scheduler_config.is_multi_step:
parallel_config.worker_cls = \
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
if envs.VLLM_USE_V1:


Does it mean that in V1 there is no distinction between MultiStepTPUWorker and TPUWorker?

Member:

The purpose of V1 is to remove the need for multistep scheduling so we can simplify the scheduler.

# TODO: Add support for these
if envs.VLLM_USE_V1:
if vllm_config.cache_config.enable_prefix_caching:
logger.info("[V1][TPU] Disable prefix caching")


nit: should this and the logger.info below be logger.error?

Collaborator (Author):

changed to warning

Member:

Please revert these changes

Comment on lines 65 to 73
if (compilation_config.level == CompilationLevel.NO_COMPILATION
or compilation_config.level == CompilationLevel.PIECEWISE):
logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
Member:

Suggested change
if (compilation_config.level == CompilationLevel.NO_COMPILATION
or compilation_config.level == CompilationLevel.PIECEWISE):
logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
logger.warning("[TPU] Unsupported compilation level "
f"{compilation_config.level}. Forcing DYNAMO_ONCE.")

Collaborator (Author):

Good idea!

Comment on lines 70 to 82
assert compilation_config.level < CompilationLevel.PIECEWISE,\
"TPU does not support Inductor."
("Current compilation level = {} but needs to be less"
" than {}".format(
compilation_config.level,
CompilationLevel.PIECEWISE))
Member:

I would just remove this assert entirely and leave the above override+log

Collaborator (Author):

done
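
For reference, with the assert removed, the remaining guard from the suggestion above collapses to something like this (a sketch assuming the surrounding module's CompilationLevel enum and logger):

# TPU only supports the DYNAMO_ONCE compilation level.
if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
    logger.warning("[TPU] Unsupported compilation level %s. Forcing DYNAMO_ONCE.",
                   compilation_config.level)
    compilation_config.level = CompilationLevel.DYNAMO_ONCE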


# TODO: Add support for these
if envs.VLLM_USE_V1:
if vllm_config.cache_config.enable_prefix_caching:
logger.info("[V1][TPU] Disable prefix caching")
Member:

Suggested change
logger.info("[V1][TPU] Disable prefix caching")
logger.warning("[V1][TPU] Disabling prefix caching")

Collaborator (Author):

changed

vllm/v1/worker/gpu_model_runner.py (resolved review thread)
Comment on lines 248 to 253
# TODO: Remove prompt_len param here
prefill_attn_metadata.append(
PallasMetadata(
num_prefills=1,
num_prefill_tokens=prompt_len, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping.to(self.device),
multi_modal_placeholder_index_maps=None,
block_tables=None,
context_lens=None,
effective_query_lens=None,
))
Member:

Can you address this TODO?

Collaborator (Author):

Done

assert req_id is not None
req_state = self.requests[req_id]

# TODO: ASSERT NO CHUNKED PREFILL.
Member:

Implement this TODO

Collaborator (Author):

It looks like the current assert combination is good enough; a chunked prefill would make seq_len differ from req_state.num_tokens, so the existing asserts would catch it.

scheduler_output.num_scheduled_tokens[req_id])
assert seq_len == req_state.num_tokens

# TODO: Verify if req_id_to_index mapping is needed here!
Member:

Ditto

Collaborator (Author):

removed, it is an old comment

Comment on lines 450 to 452
# TODO: ASSERT NO PREFIX CACHING.
assert req_state.num_computed_tokens == 0
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])

# TODO: ASSERT NO CHUNKED PREFILL.
Member:

Could you make these asserts at the initialization level? Why would you need to assert this for each request?

Collaborator (Author):

They are now inside the platform's tpu.py; the ones here are just in case something in the code changes and breaks an assumption. All of these will change the moment we have a chunked-prefill attention kernel.
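
A sketch of what those config-time guards in the platform's tpu.py can look like (field names follow the snippets quoted above; the chunked_prefill_enabled check is an assumption here, not necessarily the PR's exact code):

# Reject or disable unsupported features once at engine initialization,
# instead of re-checking them for every request.
if envs.VLLM_USE_V1:
    if vllm_config.cache_config.enable_prefix_caching:
        logger.warning("[V1][TPU] Disabling prefix caching")
        vllm_config.cache_config.enable_prefix_caching = False
    assert not vllm_config.scheduler_config.chunked_prefill_enabled, \
        "[V1][TPU] Chunked prefill is not supported yet"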

mergify bot added the ci/build label Jan 16, 2025
Comment on lines 520 to 508
token_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
Member:

Why do you build these dummy tensors each time rather than allocating the max in the initializer and taking slices for each run like the gpu_model_runner?

Collaborator (Author):

taking slices will result in copies as well, no?
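
For context, the two approaches under discussion look roughly like this (a sketch; whether the sliced views actually avoid copies and recompilation under XLA is exactly the open question here):

import torch

# Approach in this PR: build a fresh dummy tensor for each warmup shape.
def make_dummy_token_ids(batch_size: int, seq_len: int, device) -> torch.Tensor:
    return torch.zeros((batch_size, seq_len), dtype=torch.int32, device=device)

# gpu_model_runner-style approach: preallocate once at the maximum size in the
# initializer and take views of it for each run.
class PreallocatedInputs:
    def __init__(self, max_batch_size: int, max_seq_len: int, device):
        self.token_ids = torch.zeros((max_batch_size, max_seq_len),
                                     dtype=torch.int32, device=device)

    def view_for(self, batch_size: int, seq_len: int) -> torch.Tensor:
        return self.token_ids[:batch_size, :seq_len]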

# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
world_size = self.parallel_config.world_size


In the v0 folder there are tpu_worker.py, tpu_model_runner.py, and pallas.py. Could you summarize how the three files in the v1 folder differ from their v0 counterparts?

Collaborator (Author):

The architecture of V1 is slightly different from V0, which required changing APIs. To avoid conflicts, when V1 was implemented these files were duplicated (with the necessary changes) for the NVIDIA backend. In this PR we do the same for the TPU backend, but also refactor common code into *_base.py classes to avoid duplication where possible.
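
For reference, the per-rank XLA compilation cache setup quoted at the top of this thread typically looks like the following (a sketch assuming torch_xla's persistent-cache API, with the cache root taken from a VLLM_XLA_CACHE_PATH-style environment variable):

import os

import torch_xla.runtime as xr

def init_xla_compilation_cache(parallel_config, cache_root: str) -> None:
    # Persistent compilation cache avoids recompiling XLA graphs on every
    # startup; each rank gets its own directory because different ranks can
    # trace slightly different graphs.
    world_size = parallel_config.world_size
    rank = xr.global_ordinal()
    per_rank_path = os.path.join(cache_root, f"tp{world_size}_rank{rank}")
    xr.initialize_cache(per_rank_path, readonly=False)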

alexm-redhat (Collaborator, Author) left a comment:

@mgoin @vanbasten23 thanks for the review comments!




mergify bot commented Jan 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @alexm-redhat.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@@ -89,4 +89,4 @@ repos:
name: Suggestion
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
language: system
verbose: true
verbose: true
Collaborator:

nit

@@ -8,15 +8,15 @@
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)
Collaborator:

revert

@@ -34,4 +34,4 @@ run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/worker
run_mypy vllm/v1
run_mypy vllm/v1
Collaborator:

nit

Signed-off-by: Alexander Matveev <[email protected]>
@vanbasten23 commented:

Hi @alexm-redhat, thanks for adding vLLM V1 support for TPU!
One quick question: this vLLM slide deck mentioned a few key changes in vLLM V1:

  • Simplified scheduler
  • Incremental input preparation
  • Piecewise CUDA graphs
  • Enhanced API server
  • More efficient Prefix caching
  • Fine-grained scheduling for VLMs

Could you help mark which changes are included in this PR and which will be made in future PRs?
cc @miladm
