Commit deee38e: update

gabriele committed Feb 13, 2025
1 parent 4d75ff6 commit deee38e
Showing 7 changed files with 91 additions and 69 deletions.
5 changes: 3 additions & 2 deletions include/flexflow/batch_config.h
@@ -213,8 +213,9 @@ struct InferenceResult {
   float gumbel_logits[BatchConfig::MAX_NUM_TOKENS *
                       BatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];

-  BatchConfig::TokenId debug_topk_tokens[BatchConfig::MAX_NUM_TOKENS * BatchConfig::MAX_K_LOGITS];
-  half debug_topk_logits[BatchConfig::MAX_NUM_TOKENS * BatchConfig::MAX_K_LOGITS];
+  // BatchConfig::TokenId debug_topk_tokens[BatchConfig::MAX_NUM_TOKENS * BatchConfig::MAX_K_LOGITS];
+  // half debug_topk_logits[BatchConfig::MAX_NUM_TOKENS * BatchConfig::MAX_K_LOGITS];
+  // half debug_argmax_logits[BatchConfig::MAX_NUM_TOKENS];
   InferenceResult() : num_token_ids(0), num_gumbel_logits(0) {}
   InferenceResult(InferenceResult const &other);
   friend std::ostream &operator<<(std::ostream &os, InferenceResult const &ir);
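
The debug buffers commented out above flatten a per-token top-k table into 1-D arrays. A minimal sketch of the row-major indexing they relied on, with illustrative constants (the real values of MAX_NUM_TOKENS and MAX_K_LOGITS live in BatchConfig and are not shown in this diff):

    // Sketch only: how a (token i, rank j) entry maps into the flattened
    // debug arrays. MAX_K_LOGITS is the row stride.
    #include <cstddef>

    constexpr int MAX_NUM_TOKENS = 1024; // assumed value, for illustration
    constexpr int MAX_K_LOGITS = 5;      // mirrors the hard-coded arg_k = 5 below

    inline std::size_t topk_index(int token, int rank) {
      // Row-major layout: token 0's k entries first, then token 1's, ...
      return static_cast<std::size_t>(token) * MAX_K_LOGITS + rank;
    }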
9 changes: 5 additions & 4 deletions include/flexflow/ops/argmax.h
@@ -19,10 +19,11 @@ class ArgMaxMeta : public OpMeta {
   size_t temp_storage_bytes = 0;
   int *d_offsets;
   void *d_out;
-  half *topk_out_vals;
-  int *topk_out_indices;
-  std::unordered_map<cudaStream_t, raft::device_resources *> device_resources;
-  int arg_k;
+  // half *topk_out_vals;
+  // int *topk_out_indices;
+  // half *argmax_logits;
+  // std::unordered_map<cudaStream_t, raft::device_resources *> device_resources;
+  // int arg_k;
   Realm::RegionInstance reserveInst;
   ArgMaxMeta(FFHandler handler,
              Op const *op,
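
The device_resources map removed here cached one raft handle per CUDA stream, so repeated forward calls on the same stream could reuse a handle instead of constructing one per launch. A sketch of that caching pattern, under the same assumption the original code makes (raft::device_resources is constructible from a cudaStream_t):

    // Sketch of the per-stream handle cache this commit disables.
    #include <unordered_map>
    #include <cuda_runtime.h>
    #include <raft/core/device_resources.hpp>

    std::unordered_map<cudaStream_t, raft::device_resources *> device_resources;

    raft::device_resources *handle_for(cudaStream_t stream) {
      auto it = device_resources.find(stream);
      if (it == device_resources.end()) {
        // First use of this stream: create and cache a handle.
        it = device_resources.emplace(stream, new raft::device_resources(stream)).first;
      }
      return it->second;
    }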
10 changes: 6 additions & 4 deletions src/ops/argmax.cc
@@ -405,11 +405,13 @@ InferenceResult

   download_tensor<BatchConfig::TokenId>(indices.get_int32_ptr(), ir.token_ids, batch_size);

-  memset(ir.debug_topk_tokens, 0, BatchConfig::MAX_NUM_TOKENS * BatchConfig::MAX_K_LOGITS * sizeof(BatchConfig::TokenId));
-  memset(ir.debug_topk_logits, 0, BatchConfig::MAX_NUM_TOKENS * BatchConfig::MAX_K_LOGITS * sizeof(half));
+  // memset(ir.debug_topk_tokens, 0, BatchConfig::MAX_NUM_TOKENS * BatchConfig::MAX_K_LOGITS * sizeof(BatchConfig::TokenId));
+  // memset(ir.debug_topk_logits, 0, BatchConfig::MAX_NUM_TOKENS * BatchConfig::MAX_K_LOGITS * sizeof(half));
+  // memset(ir.debug_argmax_logits, 0, BatchConfig::MAX_NUM_TOKENS * sizeof(half));

-  download_tensor<BatchConfig::TokenId>(m->topk_out_indices, ir.debug_topk_tokens, batch_size * m->arg_k);
-  download_tensor<half>(m->topk_out_vals, ir.debug_topk_logits, batch_size * m->arg_k);
+  // download_tensor<BatchConfig::TokenId>(m->topk_out_indices, ir.debug_topk_tokens, batch_size * m->arg_k);
+  // download_tensor<half>(m->topk_out_vals, ir.debug_topk_logits, batch_size * m->arg_k);
+  // download_tensor<half>(m->argmax_logits, ir.debug_argmax_logits, batch_size);

   return ir;
 }
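
download_tensor is FlexFlow's device-to-host copy helper; the debug path disabled above pulled batch_size * arg_k logits and indices off the GPU on every decoding step. A hypothetical sketch of the helper's shape, assuming it wraps a plain cudaMemcpy (the real implementation lives elsewhere in the repository):

    // Hypothetical sketch, for illustration only.
    #include <cassert>
    #include <cstddef>
    #include <cuda_runtime.h>

    template <typename T>
    void download_tensor(T const *device_ptr, T *host_ptr, std::size_t num_elements) {
      assert(device_ptr != nullptr && host_ptr != nullptr);
      cudaError_t err = cudaMemcpy(host_ptr, device_ptr, num_elements * sizeof(T),
                                   cudaMemcpyDeviceToHost);
      assert(err == cudaSuccess);
    }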
60 changes: 33 additions & 27 deletions src/ops/argmax.cu
@@ -36,13 +36,15 @@ template <typename DT>
 __global__ void copy_result(cub::KeyValuePair<int, DT> *d_out,
                             int *indices,
                             float *prob_ptr,
+                            // DT *debug_argmax_logits,
                             int batch_size,
                             bool beam_search) {
   CUDA_KERNEL_LOOP(i, batch_size) {
     indices[i] = d_out[i].key;
     if (beam_search) {
       prob_ptr[i] = static_cast<float>(d_out[i].value);
     }
+    // debug_argmax_logits[i] = d_out[i].value;
   }
 }

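copy_result iterates with CUDA_KERNEL_LOOP, the usual grid-stride idiom, so any launch configuration covers all batch_size entries. A sketch of how such a macro is conventionally defined (FlexFlow's own definition may differ in detail):

    // Conventional grid-stride loop; each thread starts at its global index
    // and strides by the total thread count until n elements are covered.
    #define CUDA_KERNEL_LOOP(i, n)                                   \
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);   \
           i += blockDim.x * gridDim.x)
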
@@ -83,25 +85,26 @@ void ArgMax::forward_kernel(ArgMaxMeta *m,
       stream>>>(static_cast<cub::KeyValuePair<int, DT> *>(m->d_out),
                 indices_ptr,
                 prob_ptr,
+                // (DT*)m->argmax_logits,
                 batch_size,
                 m->beam_search);
-  // now run arg topk
-  // assert(bc->num_active_requests() >= 0);
-  if (m->device_resources.find(stream) == m->device_resources.end()) {
-    m->device_resources[stream] = new raft::device_resources(stream);
-  }
-  raft::device_resources *handle = m->device_resources[stream];
-  raft::matrix::detail::select_k(*handle,
-                                 input_ptr,
-                                 (int *)nullptr,
-                                 batch_size,
-                                 (size_t)length,
-                                 m->arg_k,
-                                 (DT*)(m->topk_out_vals),
-                                 m->topk_out_indices,
-                                 /*select_min=*/false,
-                                 false);
+  // // now run arg topk
+  // // assert(bc->num_active_requests() >= 0);
+  // if (m->device_resources.find(stream) == m->device_resources.end()) {
+  //   m->device_resources[stream] = new raft::device_resources(stream);
+  // }
+  // raft::device_resources *handle = m->device_resources[stream];
+  // raft::matrix::detail::select_k(*handle,
+  //                                input_ptr,
+  //                                (int *)nullptr,
+  //                                batch_size,
+  //                                (size_t)length,
+  //                                m->arg_k,
+  //                                (DT*)(m->topk_out_vals),
+  //                                m->topk_out_indices,
+  //                                /*select_min=*/false,
+  //                                false);
   // print_tensor<int>(indices_ptr, 4, "argmax op");
@@ -182,12 +185,14 @@ ArgMaxMeta::ArgMaxMeta(FFHandler handler,
                    : sizeof(cub::KeyValuePair<int, half>) * batch_size) +
                prob_size * sizeof(float);
-  arg_k = 5;
-  assert(data_type == DT_HALF);
-  size_t topk_num_elements = arg_k * batch_size;
-  size_t topk_out_vals_size = topk_num_elements * sizeof(half);
-  size_t topk_out_indices_size = topk_num_elements * sizeof(int);
-  total_size += topk_out_vals_size + topk_out_indices_size;
+  // arg_k = 5;
+  // assert(data_type == DT_HALF);
+  // size_t topk_num_elements = arg_k * batch_size;
+  // size_t topk_out_vals_size = topk_num_elements * sizeof(half);
+  // size_t topk_out_indices_size = topk_num_elements * sizeof(int);
+  // total_size += topk_out_vals_size + topk_out_indices_size;
+  // size_t argmax_logits_size = batch_size * sizeof(half);
+  // total_size += argmax_logits_size;
   gpu_mem_allocator.create_legion_instance(reserveInst, total_size);
   d_offsets = gpu_mem_allocator.allocate_instance<int>(d_offsets_size);
@@ -229,8 +234,9 @@
   }
-  topk_out_vals = gpu_mem_allocator.allocate_instance<half>(topk_num_elements);
-  topk_out_indices = gpu_mem_allocator.allocate_instance<int>(topk_num_elements);
+  // topk_out_vals = gpu_mem_allocator.allocate_instance<half>(topk_num_elements);
+  // topk_out_indices = gpu_mem_allocator.allocate_instance<int>(topk_num_elements);
+  // argmax_logits = gpu_mem_allocator.allocate_instance<half>(batch_size);
   gpu_mem_allocator.create_legion_instance(reserveInst, temp_storage_bytes);
   d_temp_storage =
@@ -242,8 +248,8 @@ ArgMaxMeta::~ArgMaxMeta(void) {
   if (reserveInst != Realm::RegionInstance::NO_INST) {
     reserveInst.destroy();
   }
-  for (auto &kv : device_resources) {
-    delete kv.second;
-  }
+  // for (auto &kv : device_resources) {
+  //   delete kv.second;
+  // }
 }
 }; // namespace FlexFlow
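
The members that survive this commit (d_offsets, d_out, d_temp_storage, temp_storage_bytes) serve CUB's two-phase segmented reduction, with one segment per row of logits. A sketch of the call pattern those members imply, assuming cub::DeviceSegmentedReduce::ArgMax is the reduction in use (the actual call site is outside this diff):

    // Phase 1 (d_temp_storage == nullptr) only writes the workspace size into
    // temp_storage_bytes; phase 2 runs the per-segment argmax for real.
    #include <cub/cub.cuh>

    template <typename DT>
    void segmented_argmax(DT const *d_in,                    // [batch_size * length] logits
                          cub::KeyValuePair<int, DT> *d_out, // [batch_size] (index, value)
                          int *d_offsets,                    // [batch_size + 1] row offsets
                          int batch_size,
                          void *d_temp_storage,
                          size_t &temp_storage_bytes,
                          cudaStream_t stream) {
      cub::DeviceSegmentedReduce::ArgMax(d_temp_storage, temp_storage_bytes,
                                         d_in, d_out, batch_size,
                                         d_offsets, d_offsets + 1, stream);
    }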
15 changes: 9 additions & 6 deletions src/runtime/batch_config.cc
@@ -303,12 +303,15 @@ InferenceResult::InferenceResult(InferenceResult const &other) {
   std::copy(other.gumbel_logits,
             other.gumbel_logits + num_gumbel_logits,
             gumbel_logits);
-  std::copy(other.debug_topk_tokens,
-            other.debug_topk_tokens + num_token_ids*5,
-            debug_topk_tokens);
-  std::copy(other.debug_topk_logits,
-            other.debug_topk_logits + num_token_ids*5,
-            debug_topk_logits);
+  // std::copy(other.debug_topk_tokens,
+  //           other.debug_topk_tokens + num_token_ids*5,
+  //           debug_topk_tokens);
+  // std::copy(other.debug_topk_logits,
+  //           other.debug_topk_logits + num_token_ids*5,
+  //           debug_topk_logits);
+  // std::copy(other.debug_argmax_logits,
+  //           other.debug_argmax_logits + num_token_ids,
+  //           debug_argmax_logits);
 }

 StreamingCacheInfo::StreamingCacheInfo() : StreamingCacheInfo(0, 0) {}
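
With the debug copies gone, the copy constructor copies only the live prefixes of the fixed-size arrays, bounded by the two count members. A condensed sketch of the resulting shape (the token_ids copy and the count assignments are assumed; only the gumbel_logits copy is visible in this hunk):

    // Condensed sketch: only the first num_token_ids / num_gumbel_logits
    // entries are meaningful, so copying stops there rather than walking
    // the whole fixed-size arrays.
    InferenceResult::InferenceResult(InferenceResult const &other) {
      num_token_ids = other.num_token_ids;
      num_gumbel_logits = other.num_gumbel_logits;
      std::copy(other.token_ids, other.token_ids + num_token_ids, token_ids);
      std::copy(other.gumbel_logits,
                other.gumbel_logits + num_gumbel_logits,
                gumbel_logits);
    }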
45 changes: 27 additions & 18 deletions src/runtime/request_manager.cc
@@ -1856,7 +1856,7 @@ void RequestManager::populate_best_suffix_tree_candidates(Request &request) {
   }
   profiling_requests[request.guid].prefix_length_per_step.push_back(request.suffix_decoding_best_prefix_length);

-  if (true) {
+  if (verbose) {
     std::cout << "Populated best suffix tree candidates for request " << request.guid << " with score " << request.suffix_decoding_best_score << std::endl;
     std::cout << "Best prefix length: " << request.suffix_decoding_best_prefix_length << std::endl;
     std::cout << "Best prefix: ";
@@ -2829,21 +2829,29 @@ void RequestManager::get_verify_results_greedy(
   profiling.generated_tokens_per_step.push_back(total_nb_generated_tokens);
 }

-void print_ir_debug_info(InferenceResult const &llm_verify_result, int num_tokens, int topk) {
-  std::cout << "Logits: ";
-  for (int i=0; i<num_tokens; i++) {
-    std::cout << i << ": [";
-    for (int j=0; j<topk; j++) {
-      std::cout << "("
-                << llm_verify_result.debug_topk_tokens[i*topk + j]
-                << ","
-                << std::fixed << std::setprecision(3) << (float)llm_verify_result.debug_topk_logits[i*topk + j]
-                << "), ";
-    }
-    std::cout << "]\n";
-  }
-  std::cout << std::endl;
-}
+// void print_ir_debug_info(InferenceResult const &llm_verify_result, int num_tokens, int topk) {
+//   std::cout << "Logits: ";
+//   for (int i=0; i<num_tokens; i++) {
+//     std::cout << i << ": [";
+//     for (int j=0; j<topk; j++) {
+//       std::cout << "("
+//                 << llm_verify_result.debug_topk_tokens[i*topk + j]
+//                 << ","
+//                 << std::fixed << std::setprecision(3) << (float)llm_verify_result.debug_topk_logits[i*topk + j]
+//                 << "), ";
+//     }
+//     std::cout << "]\n";
+//   }
+//   std::cout << "Argmax logits: ";
+//   for (int i=0; i<num_tokens; i++) {
+//     std::cout << "("
+//               << llm_verify_result.token_ids[i]
+//               << ","
+//               << std::fixed << std::setprecision(3) << (float)llm_verify_result.debug_argmax_logits[i]
+//               << "), ";
+//   }
+//   std::cout << std::endl;
+// }

 void RequestManager::get_verify_results_suffix_decoding(
     InferenceResult const &llm_verify_result) {
@@ -2918,6 +2926,7 @@ void RequestManager::get_verify_results_suffix_decoding(
       // (3) no other token has been accepted in this layer
       if (current_token == current_result &&
           last_accepted_token == current_parent_token &&
+          last_accepted_token_idx == parent_idx &&
           last_parent_idx != parent_idx) {
         request.committed_tokens.push_back(Request::CommittedToken(
             llm_cache_size + i, committed_token_index, current_token));
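
The new last_accepted_token_idx == parent_idx conjunct tightens acceptance: matching the parent's token value is no longer enough, the candidate's parent must be the exact tree position that was accepted last. A sketch of the full predicate with each operand's role spelled out (variable meanings inferred from the surrounding code, not a verbatim quote of it):

    // A speculated token is committed only if all four conditions hold.
    bool accept_candidate(int current_token,            // speculated token at this node
                          int current_result,           // token the LLM actually produced
                          int current_parent_token,     // token value at the node's parent
                          int parent_idx,               // tree index of the node's parent
                          int last_accepted_token,      // value of the last committed token
                          int last_accepted_token_idx,  // tree index of the last commit
                          int last_parent_idx) {        // parent index of the last commit
      return current_token == current_result &&              // (1) LLM agrees
             last_accepted_token == current_parent_token &&  // (2) parent value chains on
             last_accepted_token_idx == parent_idx &&        // new: parent is the exact node
             last_parent_idx != parent_idx;                  // (3) one accept per layer
    }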
@@ -2933,14 +2942,14 @@

   // 3. Add the bonus token
   int bonus_token_idx = last_accepted_token_idx+1;
-  if (true) {
+  if (verbose) {
     std::cout << "last_accepted_token_idx: " << last_accepted_token_idx << std::endl;
     std::cout << "accepted_tokens: ";
     for (int i=1; i<request.committed_tokens.size(); i++) {
       std::cout << request.committed_tokens[i].token_id << " ";
     }
     std::cout << std::endl;
-    print_ir_debug_info(llm_verify_result, request.suffix_decoding_best_token_ids.size(), 5);
+    // print_ir_debug_info(llm_verify_result, request.suffix_decoding_best_token_ids.size(), 5);
     std::cout << "bonus_token_idx: " << bonus_token_idx << std::endl;
     std::cout << "bonus token: " << llm_verify_result.token_ids[llm_result_offset + bonus_token_idx] << std::endl;
     std::cout << "found eos? " << found_eos << std::endl;
16 changes: 8 additions & 8 deletions suffix_decoding/benchmark_suffix_tree.sh
@@ -35,7 +35,7 @@ max_tree_depth=64
 # CSIZE=100000
 matching_strategies=(
     linear_token_path
-    # dynamic_token_tree
+    dynamic_token_tree
 )
 online_tree_update=(
     # true
@@ -50,7 +50,7 @@ traces=(
     # wildchat
 )
 trace_files=(
-    ${SUFFIX_DECODING_TRACES_FOLDER}/cortex-llama3.1-70b_debug.json
+    ${SUFFIX_DECODING_TRACES_FOLDER}/cortex-llama3.1-70b.json
     # ${SUFFIX_DECODING_TRACES_FOLDER}/spider-llama3.1-70b.json
     # ${SUFFIX_DECODING_TRACES_FOLDER}/magicoder25k-llama3.1-70b.json
     # ${SUFFIX_DECODING_TRACES_FOLDER}/wildchat25k-llama3.1-70b.json
@@ -75,13 +75,13 @@ for i in "${!traces[@]}"; do
     fi
     if [ "$trace" == "cortex" ]; then
         partitions=(
-            # QUESTION_SUGGESTION
-            # CATEGORIZATION
+            QUESTION_SUGGESTION
+            CATEGORIZATION
             FEATURE_EXTRACTION
-            # SQL_FANOUT1
-            # SQL_FANOUT2
-            # SQL_FANOUT3
-            # SQL_COMBINE
+            SQL_FANOUT1
+            SQL_FANOUT2
+            SQL_FANOUT3
+            SQL_COMBINE
         )
     else
         partitions=(all)
