llama/ggml: add LLM training support
- more compact progress bar
- refactor: llama_prepare_sbatch/ubatch
- llama_save_model_to_file
- gqa_mode arg for repeat_back
- llama_opt_param_filter
- ggml_graph_dup force_grads
- refactor ggml_opt, fix test-opt
1 parent a5203b4 · commit c255573

Showing 26 changed files with 1,288 additions and 333 deletions.
examples/training/CMakeLists.txt
@@ -0,0 +1,5 @@
set(TARGET llama-finetune)
add_executable(${TARGET} finetune.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
examples/training/README.md

@@ -0,0 +1,17 @@
# llama.cpp/examples/training

This directory contains examples related to language model training using llama.cpp/GGML.
So far finetuning is technically functional (for FP32 models and limited hardware setups) but the code is very much WIP.
Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory.
**For CPU training, compile llama.cpp without any additional backends such as CUDA.**
**For CUDA training, use the maximum number of GPU layers.**

Proof of concept:

``` sh
export model_name=llama_3.2-1b && export quantization=f32
./build/bin/finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
./build/bin/perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
```

The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.
examples/training/finetune.cpp

@@ -0,0 +1,97 @@
#include "arg.h" | ||
#include "common.h" | ||
#include "log.h" | ||
#include "llama.h" | ||
|
||
#include <cmath> | ||
#include <cstdio> | ||
#include <cstring> | ||
#include <ctime> | ||
#include <vector> | ||
|
||
#if defined(_MSC_VER) | ||
#pragma warning(disable: 4244 4267) // possible loss of data | ||
#endif | ||
|
||
int main(int argc, char ** argv) { | ||
common_params params; | ||
|
||
params.logits_all = true; | ||
params.escape = false; | ||
|
||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) { | ||
return 1; | ||
} | ||
|
||
if (params.use_mmap) { | ||
LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__); | ||
params.use_mmap = false; | ||
} | ||
if (params.cache_type_k == GGML_TYPE_F16) { | ||
LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); | ||
params.cache_type_k = GGML_TYPE_F32; | ||
} | ||
if (params.cache_type_v == GGML_TYPE_F16) { | ||
LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); | ||
params.cache_type_v = GGML_TYPE_F32; | ||
} | ||
|
||
common_init(); | ||
llama_backend_init(); | ||
llama_numa_init(params.numa); | ||
|
||
// load the model and apply lora adapter, if any | ||
common_init_result llama_init = common_init_from_params(params); | ||
llama_model_ptr & model = llama_init.model; | ||
llama_context_ptr & ctx = llama_init.context; | ||
|
||
if (model == NULL) { | ||
LOG_ERR("%s: unable to load model\n", __func__); | ||
return 1; | ||
} | ||
|
||
// print system information | ||
{ | ||
LOG_INF("\n"); | ||
LOG_INF("%s\n", common_params_get_system_info(params).c_str()); | ||
} | ||
|
||
constexpr float val_split = 0.05f; | ||
|
||
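    // tokenize the training text and build the dataset from it; examples are
    // taken at a stride of half the context size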
    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);

    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
    optimizer_params.adamw.alpha = 1e-7f; // learning rate

    struct llama_opt_params lopt_params {
        /*n_ctx_train     =*/ 0,
        /*param_filter    =*/ llama_opt_param_filter_all,
        /*param_filter_ud =*/ nullptr,
        /*get_opt_pars    =*/ ggml_opt_get_constant_optimizer_params,
        /*get_opt_pars_ud =*/ &optimizer_params,
    };
    llama_opt_init(ctx.get(), model.get(), lopt_params);

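    // keep the last 5% (val_split) of the data for validation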
    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval = ggml_opt_result_init();

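    // each epoch trains on the examples before idata_split and evaluates on
    // the rest; both phases report progress via the progress bar callback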
    for (int epoch = 0; epoch < 2; ++epoch) {
        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
            ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
        fprintf(stderr, "\n");

        ggml_opt_result_reset(result_train);
        ggml_opt_result_reset(result_eval);
    }
    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);

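    // write the trained weights to a new GGUF file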
    llama_model_save_to_file(model.get(), "finetuned-model.gguf");

    llama_backend_free();

    return 0;
}
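The `param_filter` field above decides which tensors are trained. As a minimal sketch (not part of this commit), a filter that freezes everything except the attention tensors could look like the following, assuming the callback receives each weight tensor plus the `param_filter_ud` pointer and returns whether that tensor should be trained, as `llama_opt_param_filter_all` does:

``` cpp
#include <cstring>

#include "ggml.h"
#include "llama.h"

// Hypothetical example: train only tensors whose name contains "attn",
// freezing all other weights. The signature mirrors the one assumed above.
static bool param_filter_attn_only(const struct ggml_tensor * tensor, void * userdata) {
    (void) userdata; // no user data needed for this filter
    return strstr(ggml_get_name(tensor), "attn") != nullptr;
}
```

Plugging it in would be a matter of setting `/*param_filter =*/ param_filter_attn_only` in `llama_opt_params`; training fewer tensors should also shrink the per-parameter state that AdamW keeps.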