diff --git a/include/flexflow/ops/argmax.h b/include/flexflow/ops/argmax.h
index ad22e6b36..4ab6a24c0 100644
--- a/include/flexflow/ops/argmax.h
+++ b/include/flexflow/ops/argmax.h
@@ -7,7 +7,7 @@
 #include "flexflow/ops/argmax_params.h"
 #include "flexflow/utils/memory_allocator.h"
 #include "raft/core/device_resources.hpp"
-#include
+// #include

 namespace FlexFlow {

diff --git a/src/ops/argmax.cc b/src/ops/argmax.cc
index b2bf8e1e2..5bd3bc9e3 100644
--- a/src/ops/argmax.cc
+++ b/src/ops/argmax.cc
@@ -17,11 +17,11 @@
 #include "flexflow/model.h"
 #include "flexflow/utils/hash_utils.h"
 #include "legion/legion_utilities.h"
-#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
-#include "flexflow/utils/cuda_helper.h"
-#else
-#include "flexflow/utils/hip_helper.h"
-#endif
+// #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
+// #include "flexflow/utils/cuda_helper.h"
+// #else
+// #include "flexflow/utils/hip_helper.h"
+// #endif

 namespace FlexFlow {
 // declare Legion names
diff --git a/src/ops/argmax.cu b/src/ops/argmax.cu
index 23abc9017..47e8a72dc 100644
--- a/src/ops/argmax.cu
+++ b/src/ops/argmax.cu
@@ -16,7 +16,7 @@
 #include "flexflow/ops/argmax.h"
 #include "flexflow/utils/cuda_helper.h"
 #include
-#include "raft/matrix/detail/select_k.cuh"
+// #include "raft/matrix/detail/select_k.cuh"

 namespace FlexFlow {

@@ -25,8 +25,8 @@ __global__ void init_offset(int batch_size, int total_eles, int *d_offsets) {
   CUDA_KERNEL_LOOP(i, (total_eles) / vocab_size + 1) {
-    // if (i % vocab_size == 0) {
-    //   d_offsets[i / vocab_size] = i
+    // if (i % vocab_size == 0) {
+    //   d_offsets[i / vocab_size] = i
     // }
     d_offsets[i] = i * vocab_size;
   }
@@ -89,8 +89,8 @@ void ArgMax::forward_kernel(ArgMaxMeta *m,
                             batch_size,
                             m->beam_search);
-  // // now run arg topk
-  // // assert(bc->num_active_requests() >= 0);
+  // now run arg topk
+  // assert(bc->num_active_requests() >= 0);
   // if (m->device_resources.find(stream) == m->device_resources.end()) {
  //   m->device_resources[stream] = new raft::device_resources(stream);
  // }
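
Note on the init_offset hunk above: the kernel now writes the segment boundaries directly, d_offsets[i] = i * vocab_size, so d_offsets holds the begin/end offsets of each request's slice of the flattened logits (one boundary per request plus a trailing end). Below is a minimal, self-contained sketch of how such boundary offsets can drive a per-request argmax; the pairing with cub::DeviceSegmentedReduce::ArgMax, the simplified kernel signature, and the toy shapes are illustrative assumptions, not necessarily what ArgMax::forward_kernel does.

    // Hypothetical sketch (not FlexFlow code): boundary offsets feeding a
    // segmented argmax. Compile with: nvcc -o seg_argmax seg_argmax.cu
    #include <cub/cub.cuh>
    #include <cstdio>

    // Simplified stand-in for the init_offset kernel in the diff: one
    // boundary per request plus a final end boundary (0, V, 2V, ..., B*V).
    __global__ void init_offset(int batch_size, int vocab_size, int *d_offsets) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i <= batch_size) {
        d_offsets[i] = i * vocab_size;
      }
    }

    int main() {
      int const batch_size = 2, vocab_size = 4;
      float h_logits[batch_size * vocab_size] = {
          0.1f, 0.9f, 0.3f, 0.2f,  // request 0: argmax is token 1
          0.5f, 0.1f, 0.2f, 0.8f}; // request 1: argmax is token 3
      float *d_logits;
      int *d_offsets;
      cub::KeyValuePair<int, float> *d_out;
      cudaMalloc(&d_logits, sizeof(h_logits));
      cudaMalloc(&d_offsets, (batch_size + 1) * sizeof(int));
      cudaMalloc(&d_out, batch_size * sizeof(*d_out));
      cudaMemcpy(d_logits, h_logits, sizeof(h_logits), cudaMemcpyHostToDevice);

      init_offset<<<1, 32>>>(batch_size, vocab_size, d_offsets);

      // Standard two-pass CUB pattern: size query, then the actual reduction.
      void *d_temp = nullptr;
      size_t temp_bytes = 0;
      cub::DeviceSegmentedReduce::ArgMax(d_temp, temp_bytes, d_logits, d_out,
                                         batch_size, d_offsets, d_offsets + 1);
      cudaMalloc(&d_temp, temp_bytes);
      cub::DeviceSegmentedReduce::ArgMax(d_temp, temp_bytes, d_logits, d_out,
                                         batch_size, d_offsets, d_offsets + 1);

      cub::KeyValuePair<int, float> h_out[batch_size];
      cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
      for (int b = 0; b < batch_size; b++) {
        printf("request %d: token %d (logit %.2f)\n",
               b, h_out[b].key, h_out[b].value);
      }
      return 0;
    }
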
diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc
index 6c93a88b0..b4faf68cd 100644
--- a/src/runtime/request_manager.cc
+++ b/src/runtime/request_manager.cc
@@ -1953,16 +1953,21 @@ BatchConfig RequestManager::prepare_suffix_decoding_batch_config() {
       new_bc.tokensInfo[new_bc.num_tokens].token_id = request.tokens.back();
       new_bc.num_tokens++;
       // add candidate tokens to the batch config
+      std::vector<int> depths = {0}; // root has depth 0
+      for (int i = 1; i < request.suffix_decoding_best_parents.size(); i++) {
+        depths.push_back(depths[request.suffix_decoding_best_parents[i]] + 1);
+      }
       for (int i = 0; i < request.suffix_decoding_best_token_ids.size(); i++) {
         new_bc.tokensInfo[new_bc.num_tokens].request_index = request_index;
         new_bc.tokensInfo[new_bc.num_tokens].abs_index_in_request =
             request.tokens.size() + i;
         new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request =
-            request.tokens.size() + i;
+            request.tokens.size() + depths[i];
         new_bc.tokensInfo[new_bc.num_tokens].token_id =
             request.suffix_decoding_best_token_ids[i];
         new_bc.num_tokens++;
       }
+      new_bc.requestsInfo[request_index].num_tokens_in_batch = request.suffix_decoding_best_token_ids.size() + 1; // +1 for the bonus token
@@ -1995,6 +2000,7 @@ BatchConfig RequestManager::prepare_suffix_decoding_batch_config() {
     if (profiling_requests[request.guid].llm_decoding_steps == 0) {
       profiling_requests[request.guid].start_decoding_time = start_time;
     }
+  }

   long long int end_time = Realm::Clock::current_time_in_microseconds();
   profiling.tree_operation_step_times.push_back((double)(end_time - start_time) * 1e-3);
@@ -2008,6 +2014,8 @@ BatchConfig RequestManager::prepare_suffix_decoding_batch_config() {
     profiling_requests[guid].speculation_end_timestamp = end_time;
   }

+  assert(new_bc.num_tokens <= max_tokens_per_batch);
+
   if (verbose) {
     std::cout << "prepare_suffix_decoding_batch_config NEW batchconfig:"
               << std::endl;
@@ -2829,9 +2837,9 @@ void RequestManager::get_verify_results_greedy(
   profiling.generated_tokens_per_step.push_back(total_nb_generated_tokens);
 }

-// void print_ir_debug_info(InferenceResult const &llm_verify_result, int num_tokens, int topk) {
-//   std::cout << "Logits: ";
-//   for (int i=0; i
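
Note on the depths change in prepare_suffix_decoding_batch_config: suffix-decoding candidates form a tree in parent-pointer form, so a token's absolute depth is its parent's depth plus one, not its position in the flat candidate list; the old request.tokens.size() + i over-counted whenever two candidates share a parent. A small self-contained sketch of the same derivation, using a hypothetical parent array (only the field name suffix_decoding_best_parents mirrors the diff):

    // Hypothetical sketch (not FlexFlow code) of the depth derivation added
    // in the diff. Candidate tree in parent-pointer form, root at index 0:
    //
    //        0
    //       / \
    //      1   2
    //      |   |
    //      3   4
    //
    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> suffix_decoding_best_parents = {-1, 0, 0, 1, 2};

      // Same loop as the patch: the root has depth 0, and every other node
      // sits one level below its parent. This assumes parents precede their
      // children (topological order), so depths[parent] is already filled in.
      std::vector<int> depths = {0};
      for (size_t i = 1; i < suffix_decoding_best_parents.size(); i++) {
        depths.push_back(depths[suffix_decoding_best_parents[i]] + 1);
      }

      // Tree depths: 0 1 1 2 2. The old flat indexing would have yielded
      // 0 1 2 3 4, overstating abs_depth_in_request for tokens 2-4.
      for (size_t i = 0; i < depths.size(); i++) {
        printf("token %zu: depth %d\n", i, depths[i]);
      }
      assert(depths[3] == 2 && depths[4] == 2);
      return 0;
    }
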