Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriele committed Feb 14, 2025
1 parent deee38e commit bc8da03
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 33 deletions.
2 changes: 1 addition & 1 deletion include/flexflow/ops/argmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "flexflow/ops/argmax_params.h"
#include "flexflow/utils/memory_allocator.h"
#include "raft/core/device_resources.hpp"
#include <unordered_map>
// #include <unordered_map>

namespace FlexFlow {

Expand Down
10 changes: 5 additions & 5 deletions src/ops/argmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
#include "flexflow/model.h"
#include "flexflow/utils/hash_utils.h"
#include "legion/legion_utilities.h"
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
#include "flexflow/utils/cuda_helper.h"
#else
#include "flexflow/utils/hip_helper.h"
#endif
// #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
// #include "flexflow/utils/cuda_helper.h"
// #else
// #include "flexflow/utils/hip_helper.h"
// #endif

namespace FlexFlow {
// declare Legion names
Expand Down
10 changes: 5 additions & 5 deletions src/ops/argmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "flexflow/ops/argmax.h"
#include "flexflow/utils/cuda_helper.h"
#include <cub/cub.cuh>
#include "raft/matrix/detail/select_k.cuh"
// #include "raft/matrix/detail/select_k.cuh"

namespace FlexFlow {

Expand All @@ -25,8 +25,8 @@ __global__ void init_offset(int batch_size,
int total_eles,
int *d_offsets) {
CUDA_KERNEL_LOOP(i, (total_eles) / vocab_size + 1) {
// if (i % vocab_size == 0) {
// d_offsets[i / vocab_size] = i;
    // if (i % vocab_size == 0) {
    //   d_offsets[i / vocab_size] = i;
// }
d_offsets[i] = i * vocab_size;
}
Expand Down Expand Up @@ -89,8 +89,8 @@ void ArgMax::forward_kernel(ArgMaxMeta *m,
batch_size,
m->beam_search);
// // now run arg topk
// // assert(bc->num_active_requests() >= 0);
// now run arg topk
// assert(bc->num_active_requests() >= 0);
// if (m->device_resources.find(stream) == m->device_resources.end()) {
// m->device_resources[stream] = new raft::device_resources(stream);
// }
Expand Down
48 changes: 32 additions & 16 deletions src/runtime/request_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1953,16 +1953,21 @@ BatchConfig RequestManager::prepare_suffix_decoding_batch_config() {
new_bc.tokensInfo[new_bc.num_tokens].token_id = request.tokens.back();
new_bc.num_tokens++;
// add candidate tokens to the batch config
std::vector<int> depths = {0}; // root has depth 0
for (int i=1; i< request.suffix_decoding_best_parents.size(); i++) {
depths.push_back(depths[request.suffix_decoding_best_parents[i]] + 1);
}
for (int i = 0; i < request.suffix_decoding_best_token_ids.size(); i++) {
new_bc.tokensInfo[new_bc.num_tokens].request_index = request_index;
new_bc.tokensInfo[new_bc.num_tokens].abs_index_in_request =
request.tokens.size() + i;
new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request =
request.tokens.size() + i;
request.tokens.size() + depths[i];
new_bc.tokensInfo[new_bc.num_tokens].token_id =
request.suffix_decoding_best_token_ids[i];
new_bc.num_tokens++;
}

new_bc.requestsInfo[request_index].num_tokens_in_batch =
request.suffix_decoding_best_token_ids.size() +
1; // +1 for the bonus token
Expand Down Expand Up @@ -1995,6 +2000,7 @@ BatchConfig RequestManager::prepare_suffix_decoding_batch_config() {
if (profiling_requests[request.guid].llm_decoding_steps == 0) {
profiling_requests[request.guid].start_decoding_time = start_time;
}

}
long long int end_time = Realm::Clock::current_time_in_microseconds();
profiling.tree_operation_step_times.push_back((double)(end_time - start_time) * 1e-3);
Expand All @@ -2008,6 +2014,8 @@ BatchConfig RequestManager::prepare_suffix_decoding_batch_config() {
profiling_requests[guid].speculation_end_timestamp = end_time;
}

assert(new_bc.num_tokens <= max_tokens_per_batch);

if (verbose) {
std::cout << "prepare_suffix_decoding_batch_config NEW batchconfig:"
<< std::endl;
Expand Down Expand Up @@ -2829,9 +2837,9 @@ void RequestManager::get_verify_results_greedy(
profiling.generated_tokens_per_step.push_back(total_nb_generated_tokens);
}

// void print_ir_debug_info(InferenceResult const &llm_verify_result, int num_tokens, int topk) {
// std::cout << "Logits: ";
// for (int i=0; i<num_tokens; i++) {
// void print_ir_debug_info(InferenceResult const &llm_verify_result, int num_tokens, int topk, int start_idx=0) {
// std::cout << "Logits: ";
// for (int i=start_idx; i<start_idx+num_tokens; i++) {
// std::cout << i << ": [";
// for (int j=0; j<topk; j++) {
// std::cout << "("
Expand All @@ -2843,7 +2851,7 @@ void RequestManager::get_verify_results_greedy(
// std::cout << "]\n";
// }
// std::cout << "Argmax logits: ";
// for (int i=0; i<num_tokens; i++) {
// for (int i=start_idx; i<start_idx+num_tokens; i++) {
// std::cout << "("
// << llm_verify_result.token_ids[i]
// << ","
Expand Down Expand Up @@ -2911,23 +2919,20 @@ void RequestManager::get_verify_results_suffix_decoding(
for (int i = 0; i < (int)request.suffix_decoding_best_token_ids.size(); i++) {
int current_token = request.suffix_decoding_best_token_ids[i];
int parent_idx = request.suffix_decoding_best_parents[i];
int current_result = llm_verify_result.token_ids[llm_result_offset + parent_idx+1];
int last_parent_idx = (last_accepted_token_idx == -1) ? -2 : request.suffix_decoding_best_parents[last_accepted_token_idx];
int current_result = llm_verify_result.token_ids[llm_result_offset + parent_idx + 1];
// int last_parent_idx = (last_accepted_token_idx == -1) ? -2 : request.suffix_decoding_best_parents[last_accepted_token_idx];
TokenId last_accepted_token = request.tokens.back();
TokenId current_parent_token = (parent_idx == -1) ? request.tokens.back()
: request.suffix_decoding_best_token_ids[parent_idx];
if (verbose) {
printf("\ti=%i: {current_token: %d, current_result: %d, last_accepted_token: %d, current_parent_token: %d, last_parent_idx: %d, parent_idx: %d, last_accepted_token_idx: %i}\n",
i, current_token, current_result, last_accepted_token, current_parent_token, last_parent_idx, parent_idx, last_accepted_token_idx);
i, current_token, current_result, last_accepted_token, current_parent_token, 0, parent_idx, last_accepted_token_idx);
}
// Accept token if:
// (1) it matches the result,
// (2) last accepted token is the token's parent
// (3) no other token has been accepted in this layer
if (current_token == current_result &&
last_accepted_token == current_parent_token &&
last_accepted_token_idx == parent_idx &&
last_parent_idx != parent_idx) {

// accept tokens if:
// (1) last accepted token is the token's parent (not another identical one)
// (2) the result corresponding to the parent token is the same as the current token
if (last_accepted_token_idx == parent_idx && current_result == current_token) {
request.committed_tokens.push_back(Request::CommittedToken(
llm_cache_size + i, committed_token_index, current_token));
request.tokens.push_back(current_token);
Expand All @@ -2943,12 +2948,23 @@ void RequestManager::get_verify_results_suffix_decoding(
// 3. Add the bonus token
int bonus_token_idx = last_accepted_token_idx+1;
if (verbose) {
std::cout << "llm_result_offset: " << llm_result_offset << std::endl;
std::cout << "inference results: ";
for (int i=llm_result_offset; i<llm_verify_result.num_token_ids; i++) {
std::cout << llm_verify_result.token_ids[i] << " ";
}
std::cout << std::endl;
std::cout << "last_accepted_token_idx: " << last_accepted_token_idx << std::endl;
std::cout << "accepted_tokens: ";
for (int i=1; i<request.committed_tokens.size(); i++) {
std::cout << request.committed_tokens[i].token_id << " ";
}
std::cout << std::endl;
// if (request.tokens.size() < 70 + request) {
// std::cout << "found it!" << std::endl;
// print_ir_debug_info(llm_verify_result, request.suffix_decoding_best_token_ids.size(), 5, llm_result_offset);
// }

// print_ir_debug_info(llm_verify_result, request.suffix_decoding_best_token_ids.size(), 5);
std::cout << "bonus_token_idx: " << bonus_token_idx << std::endl;
std::cout << "bonus token: " << llm_verify_result.token_ids[llm_result_offset + bonus_token_idx] << std::endl;
Expand Down
12 changes: 6 additions & 6 deletions suffix_decoding/benchmark_suffix_tree.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ online_tree_update=(
##################### Dataset parameters #####################
traces=(
cortex
# spider
# magicoder
# wildchat
spider
magicoder
wildchat
)
trace_files=(
${SUFFIX_DECODING_TRACES_FOLDER}/cortex-llama3.1-70b.json
# ${SUFFIX_DECODING_TRACES_FOLDER}/spider-llama3.1-70b.json
# ${SUFFIX_DECODING_TRACES_FOLDER}/magicoder25k-llama3.1-70b.json
# ${SUFFIX_DECODING_TRACES_FOLDER}/wildchat25k-llama3.1-70b.json
${SUFFIX_DECODING_TRACES_FOLDER}/spider-llama3.1-70b.json
${SUFFIX_DECODING_TRACES_FOLDER}/magicoder25k-llama3.1-70b.json
${SUFFIX_DECODING_TRACES_FOLDER}/wildchat25k-llama3.1-70b.json
)


Expand Down

0 comments on commit bc8da03

Please sign in to comment.