Operators support for updated speculative inference design #8

Open
chenzhuofu opened this issue Apr 20, 2024 · 6 comments
chenzhuofu (Collaborator) commented Apr 20, 2024

Related issues

#9 #14 #13

Description

We proposed a refactoring of the inference implementation that mainly involves Pipeline Split and Struct Simplification, and this raises some issues to discuss regarding operator (kernel) changes. I will list them here; if I missed anything, please feel free to correct me~

1. Splitting the prefilling and decoding stages

Previously, we mixed the prompt-phase and generation-phase calculations in one inference kernel (spec_inc_multihead_self_attention or tree_inc_multihead_self_attention). To support split stages, we should also split this mixed calculation.

But here's a question: should we provide prompt and generation as two distinct inference kernel ops, or keep a single op with a conditional branch inside it for the different stage calculations? The former approach would force changes in the operator DAG, so I don't think it is a good idea.

2. Simplifying the BatchConfig structure

The trivial changes have been adopted, but I haven't fully figured out how we switch from BeamSearchBC to TreeSearchBC.

In the BeamSearch version, the last layer of the SSM is beam_topk and its output is stored in BeamInferenceResult (using download_tensor). In the TreeSearch version, SsmInferenceResult is the same as BeamInferenceResult, so I guess we will still use beam_topk.

But beam_topk uses some fields, like sub_requests and beamRequestsInfo::probs, which were removed from the updated TreeSearchBC. Maybe we can discuss how to adapt it.

zikun-li (Collaborator) commented:

For the second question, it seems arg_topk can also return SsmInferenceResult. We will use arg_topk as the last operator of the SSM.
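
For illustration, a minimal sketch of an arg_topk-style selection producing an SsmInferenceResult-like output; the struct, its fields, and the function signature are assumptions for illustration, not the actual FlexFlow definitions:

```cpp
// Hypothetical sketch of an arg_topk-style selection as the last SSM
// operator. Struct and field names are illustrative assumptions.
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

struct SsmInferenceResultSketch {
  std::vector<int> token_ids;     // top-k token ids for one position
  std::vector<float> token_probs; // matching probabilities
};

// Pick the k highest-probability tokens from one position's distribution
// (assumes k <= probs.size()).
SsmInferenceResultSketch arg_topk(std::vector<float> const &probs, std::size_t k) {
  std::vector<int> idx(probs.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [&](int a, int b) { return probs[a] > probs[b]; });
  SsmInferenceResultSketch result;
  for (std::size_t i = 0; i < k; ++i) {
    result.token_ids.push_back(idx[i]);
    result.token_probs.push_back(probs[idx[i]]);
  }
  return result;
}
```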

chenzhuofu (Collaborator, Author) commented Apr 22, 2024

For question 1, I prefer keeping a single op structure for multihead attention. What do you think? @zwang86 @zikun-li

To implement this, I need to add a field current_phase to BatchConfig, so that the op can choose which kernel to execute (PROMPT or GENERATION).
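
A minimal sketch of this idea, assuming the new current_phase field; the names below (InferencePhase, BatchConfigSketch, attention_forward) are illustrative, not the actual FlexFlow definitions:

```cpp
// Sketch: a phase field on BatchConfig lets a single attention operator
// pick the right kernel internally. All names here are illustrative.
enum class InferencePhase { PROMPT, GENERATION };

struct BatchConfigSketch {
  InferencePhase current_phase = InferencePhase::PROMPT;
  // ... existing request/token metadata ...
};

void attention_forward(BatchConfigSketch const &bc /*, input/output tensors */) {
  if (bc.current_phase == InferencePhase::PROMPT) {
    // launch the prefilling (prompt) kernel
  } else {
    // launch the incremental-decoding (generation) kernel
  }
}
```

This would keep the operator DAG unchanged while the specialized prefill and decode kernels stay separate underneath.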

zwang86 (Collaborator) commented Apr 22, 2024

I believe there are different CUDA kernels for prefilling and decoding tokens already implemented by @xinhaoc, but those kernels can still live in the same operator. As discussed with @jiazhihao, we want to use those kernels because they are optimized for their use case.
Hi @xinhaoc, do you have any thoughts?

zikun-li (Collaborator) commented:

For speculative decoding, currently we have current_depth in TreeSearchBatchConfig to indicate whether to use the prompt kernel (i.e., if current_depth == 0, we should run the prompt kernel).
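
A minimal sketch of that check; only the field name current_depth follows the comment above, the rest is illustrative:

```cpp
// Sketch: prompt-phase detection from the tree-search batch config.
// Only current_depth is taken from the discussion; the rest is illustrative.
struct TreeSearchBatchConfigSketch {
  int current_depth = 0;
  // ... other batch metadata ...
};

bool use_prompt_kernel(TreeSearchBatchConfigSketch const &bc) {
  // Depth 0 means the requests are still processing their prompt tokens.
  return bc.current_depth == 0;
}
```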

chenzhuofu (Collaborator, Author) commented Apr 22, 2024

> I believe there are different CUDA kernels for prefilling and decoding tokens already implemented by @xinhaoc, but those kernels can still live in the same operator. As discussed with @jiazhihao, we want to use those kernels because they are optimized for their use case. Hi @xinhaoc, do you have any thoughts?

Yes, in CUDA we have different kernels within the one multihead_attention operator. What I meant was: "should we split them into two operators, as the higher level does?" (Now I think there seems to be no need :P)

chenzhuofu (Collaborator, Author) commented:

> For speculative decoding, currently we have current_depth in TreeSearchBatchConfig to indicate whether to use the prompt kernel (i.e., if current_depth == 0, we should run the prompt kernel).

For tree verification, do we have a similar method to determine whether we are in the prompt phase?

lockshaw transferred this issue from flexflow/flexflow-train on Dec 16, 2024