[Misc] Add offline test for disaggregated prefill #12418
base: main
Conversation
Signed-off-by: Shaoting Feng <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
good
Thanks for the PR. Please note that we expect users to learn this feature from the examples, ideally by using or modifying them directly for their own use cases, so please provide as many comments and explanations as possible. In addition, since using prefill disaggregation in offline inference is not trivial, could you also elaborate on the scenario?
prompts = [
    "Hello, my name is",
    # "Hi, your name is", # To simulate transmission failure
Can you elaborate? How does this comment simulate the failure?
This is to trigger partial prefill of requests in a batch. The prefill node receives two requests, while the decode node receives three, so the decode node only receives the KV cache for requests 1 and 3. It reuses that KV cache and performs prefill itself for request 2.
This example demonstrates how to use disaggregated prefill and how the decode node handles receiving KV cache for only a subset of the requests in a batch.
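A minimal sketch of the two prompt lists this scenario assumes (the decode-side list is inferred from the description above rather than copied from the PR diff):

```python
# Sketch only: the prefill node skips request 2, so the decode node receives
# the KV cache for requests 1 and 3 and must prefill request 2 itself.
prefill_prompts = [
    "Hello, my name is",
    # "Hi, your name is",  # omitted to simulate a transmission failure
    "Tell me a very long story",
]
decode_prompts = [
    "Hello, my name is",
    "Hi, your name is",  # no KV cache is transferred for this request
    "Tell me a very long story",
]
```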
def run_prefill(prefill_done):
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
Please add comments to explain that we use GPU 0 for prefill and GPU 1 for decode.
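For illustration, the requested comments could look like this; run_decode and its body are assumed here, not quoted from the diff:

```python
import os

def run_prefill(prefill_done):
    # The prefill node is pinned to GPU 0; the decode node uses GPU 1.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    ...

def run_decode(prefill_done):
    # The decode node is pinned to GPU 1.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    ...
```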
# "Hi, your name is", # To simulate transmission failure | ||
"Tell me a very long story", | ||
] | ||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) |
max_tokens=1 seems not a very good disaggregated prefill example.
Thanks for your comment. In the offline case, this example also plays the role of the proxy API server. As shown in the README for disaggregated prefill (https://github.com/vllm-project/vllm/tree/main/vllm/distributed/kv_transfer), the max_tokens of the prefill node should be set to 1.
Likewise, in the online case, benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py sets prefill_request['max_tokens'] = 1 in its handle_request function.
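A small sketch of that split between the two nodes; the decode-side max_tokens value is illustrative, not taken from the PR:

```python
from vllm import SamplingParams

# Prefill node: only the KV cache is needed, so generate a single token.
prefill_sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

# Decode node: use the real generation budget (value chosen for illustration).
decode_sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=100)
```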
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
          kv_transfer_config=ktc,
          max_model_len=2000,
          gpu_memory_utilization=0.8)
Some GPUs may hit OOM with this ratio. Please document the GPU you used and suggest a ratio for other GPUs.
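One way the example could document this, sketched here; the GPU size and alternative values below are placeholders, not figures stated in this thread:

```python
# Placeholder comment: document the GPU actually used, e.g.
# "Tested on a single 80GB GPU with gpu_memory_utilization=0.8."
# On GPUs with less memory, lower the ratio (e.g. 0.6) and/or reduce
# max_model_len to avoid out-of-memory errors.
GPU_MEMORY_UTILIZATION = 0.8
MAX_MODEL_LEN = 2000
```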
# To keep the prefill node running in case the decode node is not done
try:
    while True:
        time.sleep(1)
except KeyboardInterrupt:
    print("Script stopped by user.")
This won't stop until the user presses Ctrl+C, right? Then this is not a good offline example...
The prefill node will automatically terminate once the decode node completes. The code is as follows:
# Terminate the prefill node when decode is finished
decode_process.join()
prefill_process.terminate()
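For context, a runnable sketch of the orchestration those lines belong to, using the process and event names that appear in this thread (the worker bodies are stubbed out and only assumed to behave as described above):

```python
import time
from multiprocessing import Event, Process

def run_prefill(prefill_done):
    # ... run prefill on GPU 0; the KV cache is transferred to the decode node ...
    prefill_done.set()
    # Keep the prefill node alive until the main process terminates it.
    while True:
        time.sleep(1)

def run_decode(prefill_done):
    # Wait until prefill has produced the KV cache, then run decode on GPU 1.
    prefill_done.wait()
    # ... run decode using the received KV cache ...

if __name__ == "__main__":
    prefill_done = Event()
    prefill_process = Process(target=run_prefill, args=(prefill_done,))
    decode_process = Process(target=run_decode, args=(prefill_done,))
    prefill_process.start()
    decode_process.start()

    # Terminate the prefill node when decode is finished
    decode_process.join()
    prefill_process.terminate()
```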
Signed-off-by: Shaoting Feng <[email protected]>
@comaniac @WangErXiao Thank you very much for your feedback. I have added additional comments and addressed your questions. An offline disaggregated prefill example is valuable for the disaggregated prefill roadmap, as it simplifies debugging. If there are any further concerns about the code, I would greatly appreciate your input.
This PR adds an offline test for the disaggregated prefill use case.