diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index ea06fcdbf76ee..bc9f6fa682f96 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -36,6 +36,8 @@ class IMatrixCollector { void set_parameters(StatParams&& params) { m_params = std::move(params); } bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); void save_imatrix() const; + bool load_imatrix(const char * file_name, bool add); + static bool load_imatrix(const char * file_name, std::unordered_map& imatrix); private: std::unordered_map m_stats; StatParams m_params; @@ -189,6 +191,57 @@ void IMatrixCollector::save_imatrix(const char * fname) const { } } +bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_map& imatrix_data) { + std::ifstream in(imatrix_file, std::ios::binary); + if (!in) { + printf("%s: failed to open %s\n",__func__,imatrix_file); + return false; + } + int n_entries; + in.read((char*)&n_entries, sizeof(n_entries)); + if (in.fail() || n_entries < 1) { + printf("%s: no data in file %s\n", __func__, imatrix_file); + return false; + } + for (int i = 0; i < n_entries; ++i) { + int len; in.read((char *)&len, sizeof(len)); + std::vector name_as_vec(len+1); + in.read((char *)name_as_vec.data(), len); + if (in.fail()) { + printf("%s: failed reading name for entry %d from %s\n",__func__,i+1,imatrix_file); + return false; + } + name_as_vec[len] = 0; + std::string name{name_as_vec.data()}; + auto& e = imatrix_data[std::move(name)]; + int ncall; + in.read((char*)&ncall, sizeof(ncall)); + int nval; + in.read((char *)&nval, sizeof(nval)); + if (in.fail() || nval < 1) { + printf("%s: failed reading number of values for entry %d\n",__func__,i); + imatrix_data = {}; + return false; + } + e.values.resize(nval); + in.read((char*)e.values.data(), nval*sizeof(float)); + if (in.fail()) { + printf("%s: failed reading data for entry %d\n",__func__,i); + imatrix_data = {}; + return false; + } + e.ncall = ncall; + } + return true; +} + +bool IMatrixCollector::load_imatrix(const char * file_name, bool add) { + if (!add) { + m_stats.clear(); + } + return load_imatrix(file_name, m_stats); +} + static IMatrixCollector g_collector; static bool ik_collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) { @@ -269,7 +322,7 @@ static void process_logits( } } -static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl) { +static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) { const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); const int n_ctx = llama_n_ctx(ctx); @@ -282,6 +335,15 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool auto tim2 = std::chrono::high_resolution_clock::now(); fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast(tim2-tim1).count()); + if (from_chunk > 0) { + if (size_t((from_chunk + 2)*n_ctx) >= tokens.size()) { + fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, from_chunk); + return false; + } + fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, from_chunk, from_chunk*n_ctx); + tokens.erase(tokens.begin(), tokens.begin() + from_chunk*n_ctx); + } + if (int(tokens.size()) < 2*n_ctx) { fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx, n_ctx); @@ -402,7 +464,10 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool int main(int argc, char ** argv) { StatParams sparams; + std::string prev_result_file; + std::string combine_files; bool compute_ppl = true; + int from_chunk = 0; std::vector args; args.push_back(argv[0]); int iarg = 1; @@ -423,6 +488,13 @@ int main(int argc, char ** argv) { compute_ppl = false; } else if (arg == "--keep-imatrix") { sparams.keep_every = std::stoi(argv[++iarg]); + } else if (arg == "--continue-from") { + prev_result_file = argv[++iarg]; + } else if (arg == "--combine") { + combine_files = argv[++iarg]; + } + else if (arg == "--from-chunk") { + from_chunk = std::stoi(argv[++iarg]); } else { args.push_back(argv[iarg]); } @@ -436,14 +508,50 @@ int main(int argc, char ** argv) { } } + g_collector.set_parameters(std::move(sparams)); + + if (!combine_files.empty()) { + std::vector files; + size_t pos = 0; + while (true) { + auto new_pos = combine_files.find(',', pos); + if (new_pos != std::string::npos) { + files.emplace_back(combine_files.substr(pos, new_pos - pos)); + pos = new_pos + 1; + } else { + files.emplace_back(combine_files.substr(pos)); + break; + } + } + if (files.size() < 2) { + fprintf(stderr, "You must provide at least two comma separated files to use --combine\n"); + return 1; + } + printf("Combining the following %d files\n", int(files.size())); + for (auto& file : files) { + printf(" %s\n", file.c_str()); + if (!g_collector.load_imatrix(file.c_str(), true)) { + fprintf(stderr, "Failed to load %s\n", file.c_str()); + return 1; + } + } + g_collector.save_imatrix(); + return 0; + } + + if (!prev_result_file.empty()) { + if (!g_collector.load_imatrix(prev_result_file.c_str(), false)) { + fprintf(stderr, "=============== Failed to load %s\n", prev_result_file.c_str()); + return 1; + } + } + gpt_params params; params.n_batch = 512; if (!gpt_params_parse(args.size(), args.data(), params)) { return 1; } - g_collector.set_parameters(std::move(sparams)); - params.logits_all = true; params.n_batch = std::min(params.n_batch, params.n_ctx); @@ -495,7 +603,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s\n", get_system_info(params).c_str()); } - bool OK = compute_imatrix(ctx, params, compute_ppl); + bool OK = compute_imatrix(ctx, params, compute_ppl, from_chunk); if (!OK) { return 1; }