llama/ggml: add LLM training support
- more compact progress bar
- refactor: llama_prepare_sbatch/ubatch
- llama_save_model_to_file
- gqa_mode arg for repeat_back
- llama_opt_param_filter
- ggml_graph_dup force_grads
- refactor ggml_opt, fix test-opt
1 parent 9c8dcef · commit a315cac
Showing 28 changed files with 1,514 additions and 490 deletions.
New file: examples/training/CMakeLists.txt (+5 lines)
```cmake
set(TARGET llama-finetune)
add_executable(${TARGET} finetune.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
```
New file: examples/training/README.md (+4 lines)
```markdown
# llama.cpp/examples/training

This directory contains examples related to language model training using llama.cpp/GGML.
So far finetuning is technically functional (for FP32 models) but the code is very much WIP.
```
New file: examples/training/finetune.cpp (+97 lines)
#include "arg.h" | ||
#include "common.h" | ||
#include "log.h" | ||
#include "llama.h" | ||
|
||
#include <cmath> | ||
#include <cstdio> | ||
#include <cstring> | ||
#include <ctime> | ||
#include <vector> | ||
|
||
#if defined(_MSC_VER) | ||
#pragma warning(disable: 4244 4267) // possible loss of data | ||
#endif | ||
|
||
int main(int argc, char ** argv) { | ||
common_params params; | ||
|
||
params.logits_all = true; | ||
params.escape = false; | ||
|
||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) { | ||
return 1; | ||
} | ||
|
||
if (params.use_mmap) { | ||
LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__); | ||
params.use_mmap = false; | ||
} | ||
if (params.cache_type_k == GGML_TYPE_F16) { | ||
LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); | ||
params.cache_type_k = GGML_TYPE_F32; | ||
} | ||
if (params.cache_type_v == GGML_TYPE_F16) { | ||
LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); | ||
params.cache_type_v = GGML_TYPE_F32; | ||
} | ||
|
||
common_init(); | ||
llama_backend_init(); | ||
llama_numa_init(params.numa); | ||
|
||
// load the model and apply lora adapter, if any | ||
common_init_result llama_init = common_init_from_params(params); | ||
llama_model_ptr & model = llama_init.model; | ||
llama_context_ptr & ctx = llama_init.context; | ||
|
||
if (model == NULL) { | ||
LOG_ERR("%s: unable to load model\n", __func__); | ||
return 1; | ||
} | ||
|
||
// print system information | ||
{ | ||
LOG_INF("\n"); | ||
LOG_INF("%s\n", common_params_get_system_info(params).c_str()); | ||
} | ||
|
||
constexpr float val_split = 0.05f; | ||
|
||
std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true); | ||
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2); | ||
|
||
struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr); | ||
optimizer_params.adamw.alpha = 1e-7f; // learning rate | ||
|
||
struct llama_opt_params lopt_params { | ||
/*n_ctx_train =*/ 0, | ||
/*param_filter =*/ llama_opt_param_filter_all, | ||
/*param_filter_ud =*/ nullptr, | ||
/*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params, | ||
/*get_opt_pars_ud =*/ &optimizer_params, | ||
}; | ||
llama_opt_init(ctx.get(), model.get(), lopt_params); | ||
|
||
const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split); | ||
|
||
ggml_opt_result_t result_train = ggml_opt_result_init(); | ||
ggml_opt_result_t result_eval = ggml_opt_result_init(); | ||
|
||
for (int epoch = 0; epoch < 2; ++epoch) { | ||
llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split, | ||
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar); | ||
fprintf(stderr, "\n"); | ||
|
||
ggml_opt_result_reset(result_train); | ||
ggml_opt_result_reset(result_eval); | ||
} | ||
ggml_opt_result_free(result_train); | ||
ggml_opt_result_free(result_eval); | ||
|
||
llama_save_model_to_file(model.get(), "finetuned-model.gguf"); | ||
|
||
llama_backend_free(); | ||
|
||
return 0; | ||
} |
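With `val_split = 0.05f`, `idata_split` comes out to 95% of `ggml_opt_dataset_ndata(dataset)`, so for a dataset of, say, 100 sequences the epoch loop trains on the first 95 and presumably evaluates on the remaining 5.

The example above trains every parameter by passing `llama_opt_param_filter_all`. As an illustration of the `param_filter` hook named in the commit message, the sketch below restricts training to tensors whose names contain a given substring. The callback signature `(const ggml_tensor *, void *) -> bool` is an assumption inferred from the `param_filter`/`param_filter_ud` pair shown in `llama_opt_params`, and `train_only_matching` is a hypothetical helper, not part of this commit.

```cpp
#include <cstring>

#include "ggml.h"
#include "llama.h"

// Hypothetical filter: train only tensors whose name contains the substring passed
// via userdata; all other weights stay frozen. The (const ggml_tensor *, void *) -> bool
// shape is an assumption matching the param_filter/param_filter_ud fields above.
static bool train_only_matching(const struct ggml_tensor * tensor, void * userdata) {
    const char * pattern = static_cast<const char *>(userdata);
    return strstr(ggml_get_name(tensor), pattern) != nullptr;
}

// Usage sketch: plug the filter into llama_opt_params instead of llama_opt_param_filter_all.
//
//     const char * pattern = "attn";
//     struct llama_opt_params lopt_params {
//         /*n_ctx_train     =*/ 0,
//         /*param_filter    =*/ train_only_matching,
//         /*param_filter_ud =*/ (void *) pattern,
//         /*get_opt_pars    =*/ ggml_opt_get_constant_optimizer_params,
//         /*get_opt_pars_ud =*/ &optimizer_params,
//     };
```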
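Because `get_opt_pars` is a callback rather than a fixed value, the learning rate does not have to stay constant the way it does with `ggml_opt_get_constant_optimizer_params` and a fixed `adamw.alpha`. A minimal sketch of a decaying schedule follows, assuming the callback shape `ggml_opt_optimizer_params (*)(void *)` implied by the `get_opt_pars`/`get_opt_pars_ud` pair above; the `lr_sched` struct, the decay values, and the `ggml-opt.h` include are illustrative assumptions, not taken from the commit.

```cpp
#include <cmath>

#include "ggml-opt.h" // assumed header for the ggml_opt_* API

// Hypothetical state for a simple exponential learning-rate decay.
struct lr_sched {
    float base_lr = 1e-7f; // starting learning rate
    float decay   = 0.99f; // multiplicative decay per query
    int   n_calls = 0;     // number of times the callback has been invoked
};

// Assumed callback shape mirroring ggml_opt_get_constant_optimizer_params:
// it receives get_opt_pars_ud and returns the parameters for the next step.
static struct ggml_opt_optimizer_params decaying_opt_pars(void * userdata) {
    lr_sched * sched = static_cast<lr_sched *>(userdata);

    struct ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
    result.adamw.alpha = sched->base_lr * powf(sched->decay, (float) sched->n_calls++);
    return result;
}
```

Passing `decaying_opt_pars` as `get_opt_pars` and a `lr_sched` instance as `get_opt_pars_ud` in `llama_opt_params` would then lower the learning rate as training progresses.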