[WIP] [V1] TPU support #11936

Conversation
return PrefillInputData(
    request_ids=prefill_request_ids,
    prompt_lens=prefill_prompt_lens,
    token_ids=prefill_token_ids,
    position_ids=prefill_position_ids,
    attn_metadata=prefill_attn_metadata,
)
Remove the PrefillInputData data structure and make it consistent with gpu_model_runner?
This will be removed the moment Google provides the new attention kernel that supports chunked prefill.
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
This is almost identical to the current gpu_model_runner implementation; consider reusing it instead of duplicating?
Added model_runner_base.py to hold the common functions.
Successfully ran an eval on GSM8K.
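A minimal sketch of the shared-base idea mentioned above, assuming illustrative class and method names rather than the PR's actual code:

# Illustrative sketch only: common input-preparation logic lives in a base
# class so the GPU and TPU model runners do not duplicate it.
class ModelRunnerBase:

    def _prepare_inputs(self, scheduler_output):
        # Bookkeeping shared by both backends: gather token ids, positions,
        # and per-request lengths from the scheduler output.
        ...


class TPUModelRunner(ModelRunnerBase):

    def execute_model(self, scheduler_output):
        # Reuse the common preparation, then run the TPU-specific forward.
        inputs = self._prepare_inputs(scheduler_output)
        ...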
vllm/platforms/tpu.py
    parallel_config.worker_cls = \
        "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
if envs.VLLM_USE_V1:
    parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TRUWorker"

Suggested change:
- parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TRUWorker"
+ parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
vllm/v1/worker/tpu_worker.py
@@ -0,0 +1,148 @@
"""A GPU worker class."""

Suggested change:
- """A GPU worker class."""
+ """A TPU worker class."""
nice catch! :)
vllm/platforms/tpu.py
# TPU only supports DYNAMO_ONCE compilation level
if (compilation_config.level == CompilationLevel.NO_COMPILATION
        or compilation_config.level == CompilationLevel.PIECEWISE):
The assert below makes sure it fails when compilation_config.level < CompilationLevel.PIECEWISE. So do you still need to check if compilation_config.level == CompilationLevel.PIECEWISE here?
We can remove it
if scheduler_config.is_multi_step:
    parallel_config.worker_cls = \
        "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
if envs.VLLM_USE_V1:
Does it mean in V1, there is no distinction between MultiStepTPUWorker and TPUWorker?
The purpose of V1 is to remove the need for multistep scheduling so we can simplify the scheduler.
vllm/platforms/tpu.py
# TODO: Add support for these
if envs.VLLM_USE_V1:
    if vllm_config.cache_config.enable_prefix_caching:
        logger.info("[V1][TPU] Disable prefix caching")
nit: should the logger.info here and the one below be logger.error?
changed to warning
Please revert these changes
vllm/platforms/tpu.py
if (compilation_config.level == CompilationLevel.NO_COMPILATION
        or compilation_config.level == CompilationLevel.PIECEWISE):
    logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
Suggested change:
- if (compilation_config.level == CompilationLevel.NO_COMPILATION
-         or compilation_config.level == CompilationLevel.PIECEWISE):
-     logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
+ if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
+     logger.warning("[TPU] Unsupported compilation level "
+                    f"{compilation_config.level}. Forcing DYNAMO_ONCE.")
Good idea!
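For context, a hedged sketch of what the combined override-plus-warning could look like; the helper name and exact message are assumptions, not the PR's final code:

from vllm.config import CompilationLevel
from vllm.logger import init_logger

logger = init_logger(__name__)


def _force_dynamo_once(compilation_config) -> None:
    # TPU only supports DYNAMO_ONCE, so warn and override any other
    # requested level instead of asserting.
    if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
        logger.warning("[TPU] Unsupported compilation level %s. "
                       "Forcing DYNAMO_ONCE.", compilation_config.level)
        compilation_config.level = CompilationLevel.DYNAMO_ONCE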
vllm/platforms/tpu.py
assert compilation_config.level < CompilationLevel.PIECEWISE,\
    "TPU does not support Inductor."
    ("Current compilation level = {} but needs to be less"
     " than {}".format(
         compilation_config.level,
         CompilationLevel.PIECEWISE))
I would just remove this assert entirely and leave the above override+log
done
vllm/platforms/tpu.py
# TODO: Add support for these
if envs.VLLM_USE_V1:
    if vllm_config.cache_config.enable_prefix_caching:
        logger.info("[V1][TPU] Disable prefix caching")
Suggested change:
- logger.info("[V1][TPU] Disable prefix caching")
+ logger.warning("[V1][TPU] Disabling prefix caching")
changed
vllm/v1/worker/tpu_model_runner.py
# TODO: Remove prompt_len param here
prefill_attn_metadata.append(
    PallasMetadata(
        num_prefills=1,
        num_prefill_tokens=prompt_len,  # NOTE: This is not used.
        num_decode_tokens=0,
        slot_mapping=slot_mapping.to(self.device),
        multi_modal_placeholder_index_maps=None,
        block_tables=None,
        context_lens=None,
        effective_query_lens=None,
    ))
Can you address this TODO?
Done
vllm/v1/worker/tpu_model_runner.py
assert req_id is not None
req_state = self.requests[req_id]

# TODO: ASSERT NO CHUNKED PREFILL.
Implement this TODO
It looks like the current assert combo is good enough
vllm/v1/worker/tpu_model_runner.py
    scheduler_output.num_scheduled_tokens[req_id])
assert seq_len == req_state.num_tokens

# TODO: Verify if req_id_to_index mapping is needed here!
Ditto
Removed; it was an old comment.
vllm/v1/worker/tpu_model_runner.py
# TODO: ASSERT NO PREFIX CACHING.
assert req_state.num_computed_tokens == 0
seq_len = (req_state.num_computed_tokens +
           scheduler_output.num_scheduled_tokens[req_id])

# TODO: ASSERT NO CHUNKED PREFILL.
Could you make these asserts at the initialization level? Why would you need to assert this for each request?
They are now inside the platform's tpu.py; the asserts here are just in case something changes elsewhere in the code and breaks the assumption. All of these will change the moment we have a chunked-prefill attention kernel.
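A rough sketch of the platform-level guards the reply refers to; the prefix-caching flag follows the snippets quoted above, while the chunked-prefill attribute name is an assumption for illustration:

import vllm.envs as envs
from vllm.logger import init_logger

logger = init_logger(__name__)


def _disable_unsupported_v1_features(vllm_config) -> None:
    # Turn off features the V1 TPU backend cannot serve yet, so the
    # per-request asserts in the model runner never fire in practice.
    if not envs.VLLM_USE_V1:
        return
    if vllm_config.cache_config.enable_prefix_caching:
        logger.warning("[V1][TPU] Disabling prefix caching")
        vllm_config.cache_config.enable_prefix_caching = False
    # Attribute name below is assumed, not taken from the PR.
    if getattr(vllm_config.scheduler_config, "chunked_prefill_enabled", False):
        logger.warning("[V1][TPU] Disabling chunked prefill")
        vllm_config.scheduler_config.chunked_prefill_enabled = False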
vllm/v1/worker/tpu_model_runner.py
token_ids = torch.zeros((batch_size, seq_len),
                        dtype=torch.int32,
                        device=self.device)
Why do you build these dummy tensors each time rather than allocating the max in the initializer and taking slices for each run like the gpu_model_runner?
taking slices will result in copies as well, no?
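To make the trade-off concrete, a small illustration of the two approaches under discussion; sizes are arbitrary, and the XLA comment reflects the concern raised in the reply above rather than a measurement:

import torch

MAX_BATCH, MAX_LEN = 256, 8192

# (a) gpu_model_runner style: allocate the maximum once, slice per step.
token_buffer = torch.zeros((MAX_BATCH, MAX_LEN), dtype=torch.int32)


def tokens_by_slicing(batch_size: int, seq_len: int) -> torch.Tensor:
    # On eager GPU this is a view; under XLA each new (batch, len) shape
    # still shows up in the traced graph, so it may recompile/copy anyway.
    return token_buffer[:batch_size, :seq_len]


# (b) this PR: build a fresh padded tensor for each padded (batch, len) bucket.
def tokens_by_allocation(batch_size: int, seq_len: int) -> torch.Tensor:
    return torch.zeros((batch_size, seq_len), dtype=torch.int32)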
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
world_size = self.parallel_config.world_size
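A minimal sketch of what the quoted per-rank cache setup amounts to, assuming torch_xla's runtime cache API; the helper name, cache-root argument, and path format are assumptions:

import os

import torch_xla.runtime as xr


def init_per_rank_xla_cache(cache_root: str, world_size: int) -> None:
    # Give each rank its own cache directory, since different ranks can
    # trace slightly different XLA graphs.
    rank = xr.global_ordinal()
    per_rank_path = os.path.join(cache_root, f"tp{world_size}_rank{rank}")
    xr.initialize_cache(per_rank_path, readonly=False)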
In the v0 folder there are tpu_worker.py, tpu_model_runner.py, and pallas.py. Could you summarize how the three files in the v1 folder differ from their v0 counterparts?
The architecture of V1 is slightly different from V0, which required changing APIs. To avoid conflicts, when V1 was implemented, these files were duplicated (with the necessary changes) for the NVIDIA backend. In this PR, we do the same for the TPU backend, but also refactor the code into *_base.py classes to avoid code duplication where possible.
@mgoin @vanbasten23 thanks for the review comments!
@@ -89,4 +89,4 @@ repos:
       name: Suggestion
       entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
       language: system
-      verbose: true
+      verbose: true
nit
@@ -8,15 +8,15 @@
     "The future of AI is",
 ]
 # Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)
revert
@@ -34,4 +34,4 @@ run_mypy vllm/plugins
 run_mypy vllm/prompt_adapter
 run_mypy vllm/spec_decode
 run_mypy vllm/worker
-run_mypy vllm/v1
+run_mypy vllm/v1
nit
Signed-off-by: Alexander Matveev <[email protected]>
Hi @alexm-redhat, thanks for adding vLLM V1 support for TPU!
Could you help mark which changes are included in this PR and which are to be made in future PRs?
This PR is a rebase and modification of @robertgshaw2-redhat's original PR for TPU support in vLLM V1 from 1.5 months ago: #10241.
Currently, the TPU attention kernel has no support for mixing prefills and decodes in the same scheduler iteration. As a result, this PR separates the requests into (1) prefills and (2) decodes and executes each group separately, as sketched below. The Google team is working on a new TPU attention kernel that will allow mixing prefills and decodes; the moment it is ready, we will be able to remove the separation logic and unify the requests (which will also provide better performance).
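A simplified sketch of that separation logic; attribute names such as num_prompt_tokens are assumptions, not the PR's exact code:

def split_prefills_and_decodes(scheduler_output, requests):
    # Until the TPU attention kernel supports mixed batches, each scheduled
    # request goes into exactly one of the two groups for this step.
    prefill_ids, decode_ids = [], []
    for req_id in scheduler_output.num_scheduled_tokens:
        req_state = requests[req_id]
        if req_state.num_computed_tokens < req_state.num_prompt_tokens:
            prefill_ids.append(req_id)  # still consuming its prompt
        else:
            decode_ids.append(req_id)  # generating one token per step
    return prefill_ids, decode_ids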
Notes:
Follow-up tasks (maybe I missed something):