sampling : support for llguidance grammars (ggml-org#10224)
* initial porting of previous LLG patch
* update for new APIs
* build: integrate llguidance as an external project
* use '%llguidance' as marker to enable llg lark syntax
* add some docs
* clarify docs
* code style fixes
* remove llguidance.h from .gitignore
* fix tests when llg is enabled
* pass vocab not model to llama_sampler_init_llg()
* copy test-grammar-integration.cpp to test-llguidance.cpp
* clang fmt
* fix ref-count bug
* build and run test
* gbnf -> lark syntax
* conditionally include llguidance test based on LLAMA_LLGUIDANCE flag
* rename llguidance test file to test-grammar-llguidance.cpp
* add gh action for llg test
* align tests with LLG grammar syntax and JSON Schema spec
* llama_tokenizer() in fact requires valid utf8
* update llg
* format file
* add $LLGUIDANCE_LOG_LEVEL support
* fix whitespace
* fix warning
* include <cmath> for INFINITY
* add final newline
* fail llama_sampler_init_llg() at runtime
* Link gbnf_to_lark.py script; fix links; refer to llg docs for lexemes
* simplify #includes
* improve doc string for LLAMA_LLGUIDANCE
* typo in merge
* bump llguidance to 0.6.12
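For context, a minimal sketch of how the new sampler might be wired into a sampler chain. The chain calls are the existing llama.cpp sampler API; the "lark" grammar kind and the grammar text itself are illustrative assumptions based on the llguidance documentation, not something this commit prescribes:

```cpp
// Hypothetical usage sketch (not part of this commit); assumes
// llama_sampler_init_llg() is declared in common/sampling.h, as the
// #include at the top of the new file suggests.
#include "sampling.h"

static llama_sampler * make_llg_chain(const llama_model * model) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    // "lark" as grammar kind and the grammar string are examples per llguidance docs
    llama_sampler_chain_add(chain, llama_sampler_init_llg(vocab, "lark", "start: /[A-Z ]+/"));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
    return chain;  // free with llama_sampler_free() when done
}
```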
Showing 13 changed files with 1,555 additions and 9 deletions.
@@ -0,0 +1,270 @@
#include "sampling.h" | ||
#include "log.h" | ||
|
||
#ifdef LLAMA_USE_LLGUIDANCE | ||
|
||
# include "llguidance.h" | ||
# include <cmath> | ||
|
||
struct llama_sampler_llg { | ||
const llama_vocab * vocab; | ||
std::string grammar_kind; | ||
std::string grammar_data; | ||
LlgTokenizer * tokenizer; | ||
LlgConstraint * grammar; | ||
LlgMaskResult llg_res; | ||
bool has_llg_res; | ||
}; | ||

static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
                                             const char * grammar_data) {
    LlgConstraintInit cinit;
    llg_constraint_init_set_defaults(&cinit, tokenizer);
    const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
    if (log_level && *log_level) {
        cinit.log_stderr_level = atoi(log_level);
    }
    auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
    if (llg_get_error(c)) {
        LOG_ERR("llg error: %s\n", llg_get_error(c));
        llg_free_constraint(c);
        return nullptr;
    }
    return c;
}

static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
    return "llguidance";
}

static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        LlgCommitResult res;
        llg_commit_token(ctx->grammar, token, &res);
        ctx->has_llg_res = false;
    }
}

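// Lazily compute the token mask once per position; tokens the grammar
// rejects get their logit forced to -INFINITY, and when the constraint
// reports a stop state only end-of-generation tokens remain allowed.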
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        if (!ctx->has_llg_res) {
            if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
                ctx->has_llg_res = true;
            } else {
                LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
                llg_free_constraint(ctx->grammar);
                ctx->grammar = nullptr;
            }
        }
        if (ctx->has_llg_res) {
            if (ctx->llg_res.is_stop) {
                for (size_t i = 0; i < cur_p->size; ++i) {
                    if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
                        cur_p->data[i].logit = -INFINITY;
                    }
                }
            } else {
                const uint32_t * mask = ctx->llg_res.sample_mask;
                for (size_t i = 0; i < cur_p->size; ++i) {
                    auto token = cur_p->data[i].id;
                    if ((mask[token / 32] & (1 << (token % 32))) == 0) {
                        cur_p->data[i].logit = -INFINITY;
                    }
                }
            }
        }
    }
}

static void llama_sampler_llg_reset(llama_sampler * smpl) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (!ctx->grammar) {
        return;
    }

    auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
    llg_free_constraint(ctx->grammar);
    ctx->grammar     = grammar_new;
    ctx->has_llg_res = false;
}

static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_llg *) smpl->ctx;

    auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_llg *) result->ctx;

        if (ctx->grammar) {
            result_ctx->grammar_kind = ctx->grammar_kind;
            result_ctx->grammar_data = ctx->grammar_data;
            result_ctx->grammar      = llg_clone_constraint(ctx->grammar);
            result_ctx->tokenizer    = llg_clone_tokenizer(ctx->tokenizer);
        }
    }

    return result;
}

static void llama_sampler_llg_free(llama_sampler * smpl) {
    const auto * ctx = (llama_sampler_llg *) smpl->ctx;

    if (ctx->grammar) {
        llg_free_constraint(ctx->grammar);
        llg_free_tokenizer(ctx->tokenizer);
    }

    delete ctx;
}

static llama_sampler_i llama_sampler_llg_i = {
    /* .name   = */ llama_sampler_llg_name,
    /* .accept = */ llama_sampler_llg_accept_impl,
    /* .apply  = */ llama_sampler_llg_apply,
    /* .reset  = */ llama_sampler_llg_reset,
    /* .clone  = */ llama_sampler_llg_clone,
    /* .free   = */ llama_sampler_llg_free,
};

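// Tokenization callback handed to llguidance. llama_tokenize() returns the
// negative of the number of tokens required when the output buffer is too
// small, so -r is reported back as the needed length.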
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
                                            uint32_t * output_tokens, size_t output_tokens_len) {
    const llama_vocab * vocab = (const llama_vocab *) user_data;
    int                 r     = 0;
    try {
        r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
                           true);
    } catch (const std::exception & e) {
        GGML_ABORT("llama_tokenize failed: %s\n", e.what());
    }
    if (r < 0) {
        return -r;
    }
    return r;
}

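// Build an llguidance tokenizer from the vocab. A single-entry static cache
// avoids re-flattening the vocabulary on repeated calls; both paths return a
// clone, which the caller owns and must free.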
static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
    // TODO store the tokenizer in the vocab somehow
    static const llama_vocab * vocab_cache;
    static LlgTokenizer *      tokenizer_cache;

    if (vocab_cache == vocab) {
        return llg_clone_tokenizer(tokenizer_cache);
    }

    auto tok_eos = llama_vocab_eot(vocab);
    if (tok_eos == LLAMA_TOKEN_NULL) {
        tok_eos = llama_vocab_eos(vocab);
    }

    size_t vocab_size = llama_vocab_n_tokens(vocab);

    auto token_lens = new uint32_t[vocab_size];
    // we typically have ~7 bytes per token; let's go on the safe side here
    auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
    auto token_bytes      = new uint8_t[token_bytes_size];

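    // Flatten the vocabulary into contiguous (token_lens, token_bytes) arrays.
    // Special tokens that detokenize to an empty string with special=false are
    // rendered again with special=true and prefixed with 0xFF so llguidance
    // can distinguish them from ordinary text.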
    size_t offset = 0;
    for (size_t i = 0; i < vocab_size; i++) {
        size_t max_token = 1024;
        if (token_bytes_size - offset < max_token) {
            GGML_ABORT("token_bytes buffer too small\n");
        }

        llama_token token = i;
        auto        dp    = (char *) token_bytes + offset;
        auto        size  = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
        if (size < 0) {
            GGML_ABORT("llama_detokenize failed\n");
        }
        if (size == 0) {
            size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
            if (size < 0) {
                GGML_ABORT("llama_detokenize failed\n");
            }
            if (size != 0) {
                *dp = '\xff';  // special token prefix marker
                size += 1;
            }
        }

        token_lens[i] = size;
        offset += size;
    }

    LlgTokenizerInit tinit = {
        /* .vocab_size                         = */ (uint32_t) vocab_size,
        /* .tok_eos                            = */ (uint32_t) tok_eos,
        /* .token_lens                         = */ token_lens,
        /* .token_bytes                        = */ token_bytes,
        /* .tokenizer_json                     = */ nullptr,
        /* .tokenize_assumes_string            = */ true,
        /* .tokenize_fn                        = */ llama_sampler_llg_tokenize_fn,
        /* .use_approximate_greedy_tokenize_fn = */ false,
        /* .tokenize_user_data                 = */ vocab,
    };

    char           error_buffer[1024];
    LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));

    delete[] token_bytes;
    delete[] token_lens;

    if (tokenizer == nullptr) {
        LOG_ERR("llg tokenizer error: %s\n", error_buffer);
        return tokenizer;
    }

    if (tokenizer_cache) {
        llg_free_tokenizer(tokenizer_cache);
    }
    vocab_cache     = vocab;
    tokenizer_cache = tokenizer;

    return llg_clone_tokenizer(tokenizer_cache);
}

llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
                                       const char * grammar_data) {
    auto * ctx = new llama_sampler_llg;

    if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
        auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ grammar_kind,
            /* .grammar_data = */ grammar_data,
            /* .tokenizer    = */ tokenizer,
            /* .grammar      = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
            /* .llg_res      = */ {},
            /* .has_llg_res  = */ false,
        };
    } else {
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ {},
            /* .grammar_data = */ {},
            /* .tokenizer    = */ nullptr,
            /* .grammar      = */ nullptr,
            /* .llg_res      = */ {},
            /* .has_llg_res  = */ false,
        };
    }

    return new llama_sampler{
        /* .iface = */ &llama_sampler_llg_i,
        /* .ctx   = */ ctx,
    };
}

#else

llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
    LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
    return nullptr;
}

#endif  // LLAMA_USE_LLGUIDANCE
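One consequence of the runtime stub above: when the build is configured without LLAMA_LLGUIDANCE, llama_sampler_init_llg() returns nullptr instead of aborting, so call sites can detect the missing feature. A minimal sketch of that check (vocab and grammar_text stand in for the caller's own values; the fallback behavior is the caller's choice):

```cpp
llama_sampler * smpl = llama_sampler_init_llg(vocab, "lark", grammar_text);
if (smpl == nullptr) {
    // built without -DLLAMA_LLGUIDANCE=ON; fall back to unconstrained
    // sampling or surface the error to the user
}
```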