[pull] master from ggerganov:master #195

Merged: 2 commits, Jan 27, 2025

8 changes: 6 additions & 2 deletions ggml/src/ggml-metal/ggml-metal.m
@@ -64,7 +64,9 @@

if (ctx->mtl_device == nil) {
ctx->mtl_device = MTLCreateSystemDefaultDevice();
}

if (ctx->mtl_device) {
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];

@@ -99,8 +101,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
ctx->mtl_device_ref_count--;

if (ctx->mtl_device_ref_count == 0) {
[ctx->mtl_device release];
ctx->mtl_device = nil;
if (ctx->mtl_device) {
[ctx->mtl_device release];
ctx->mtl_device = nil;
}
}
}

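The ggml-metal.m hunks above add a null guard before the device capability queries and before the final release, so the backend degrades gracefully when MTLCreateSystemDefaultDevice() returns nil (for example, on a machine without a Metal-capable GPU). Below is a minimal C++ sketch of the same acquire/guard/release pattern; the names (Device, DeviceContext, acquire_default_device) are hypothetical stand-ins, not the actual ggml API.

#include <cstdio>

// Stand-in for the Metal device; on some machines it may simply not exist.
struct Device {
    bool supports_simdgroup_reduction = false;
};

static Device * acquire_default_device() {
    // Stand-in for MTLCreateSystemDefaultDevice(); returning nullptr is a valid outcome.
    return nullptr;
}

struct DeviceContext {
    Device * device    = nullptr;
    int      ref_count = 0;
    bool     has_simdgroup_reduction = false;
};

static void device_acquire(DeviceContext & ctx) {
    if (ctx.device == nullptr) {
        ctx.device = acquire_default_device();
    }
    if (ctx.device) {
        // only query capabilities when a device actually exists
        ctx.has_simdgroup_reduction = ctx.device->supports_simdgroup_reduction;
    }
    ctx.ref_count++;
}

static void device_release(DeviceContext & ctx) {
    ctx.ref_count--;
    if (ctx.ref_count == 0) {
        if (ctx.device) {
            // guard the release as well, mirroring the second hunk
            delete ctx.device;
            ctx.device = nullptr;
        }
    }
}

int main() {
    DeviceContext ctx;
    device_acquire(ctx); // safe even when no device is available
    device_release(ctx);
    std::printf("device guard sketch ran without dereferencing a null device\n");
    return 0;
}
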
241 changes: 139 additions & 102 deletions src/llama.cpp
@@ -8432,74 +8432,33 @@ static enum ggml_status llama_graph_compute(
return status;
}

// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
// - lctx: llama context
// - batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
static int llama_decode_impl(
llama_context & lctx,
llama_batch inp_batch) {

lctx.is_encoding = false;

if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}

// temporary allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);

const llama_batch & batch = batch_allocr.batch;
const uint32_t n_tokens_all = batch.n_tokens;

static int llama_prepare_sbatch(
llama_context & lctx,
const llama_batch & batch,
uint32_t & n_outputs) {
const auto & model = lctx.model;
const auto & vocab = model.vocab;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;

GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
const uint32_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;

// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}
}
}

GGML_ASSERT(n_tokens_all <= cparams.n_batch);

GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");

if (lctx.t_compute_start_us == 0) {
lctx.t_compute_start_us = ggml_time_us();
}
lctx.n_queued_tokens += n_tokens_all;

auto & kv_self = lctx.kv_self;
llama_kv_slot_restorer kv_slot_restorer(kv_self);

const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = vocab.n_tokens();

uint32_t n_outputs = 0;
uint32_t n_outputs_prev = 0;

const auto n_ubatch = cparams.n_ubatch;

// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

lctx.embd_seq.clear();

// count outputs
@@ -8515,7 +8474,7 @@ static int llama_decode_impl(
}

lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* simple_split */ !lctx.kv_self.recurrent,
/* logits_all */ n_outputs == n_tokens_all);

// reserve output buffer
@@ -8524,70 +8483,148 @@ static int llama_decode_impl(
return -2;
};

while (lctx.sbatch.n_tokens > 0) {
llama_ubatch ubatch;
if (kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = lctx.sbatch.split_equal(n_ubatch);
}
return 0;
}

static int llama_prepare_ubatch(
llama_context & lctx,
llama_kv_slot_restorer & kv_slot_restorer,
llama_ubatch & ubatch,
const uint32_t n_outputs,
const uint32_t n_tokens_all) {
GGML_ASSERT(lctx.sbatch.n_tokens > 0);

auto & kv_self = lctx.kv_self;
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;

// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

if (lctx.kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
} else {
ubatch = lctx.sbatch.split_simple(n_ubatch);
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
}
const uint32_t n_tokens = ubatch.n_tokens;
} else {
ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
}

// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;
// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;

if (n_outputs == n_tokens_all) {
n_outputs_new = n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
if (n_outputs == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += int32_t(ubatch.output[i] != 0);
}
}

// needs to happen before the graph is built
lctx.n_outputs = n_outputs_new;
}

// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
llama_kv_cache_update(&lctx);

// needs to happen before the graph is built
lctx.n_outputs = n_outputs_new;
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
kv_self.head = 0;
}

int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
if (!slot) {
return 1;
}
kv_slot_restorer.save(slot);

GGML_ASSERT(n_threads > 0);
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = llama_kv_cache_get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}

// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
llama_kv_cache_update(&lctx);
return 0;
}

// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*n_tokens) {
kv_self.head = 0;
}
// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
// - lctx: llama context
// - inp_batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
static int llama_decode_impl(
llama_context & lctx,
llama_batch inp_batch) {

const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
if (!slot) {
return 1;
}
kv_slot_restorer.save(slot);
lctx.is_encoding = false;

if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = llama_kv_cache_get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}

// temporarily allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
const llama_batch & batch = batch_allocr.batch;

const auto & model = lctx.model;
const auto & vocab = model.vocab;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;

if (lctx.t_compute_start_us == 0) {
lctx.t_compute_start_us = ggml_time_us();
}
auto & kv_self = lctx.kv_self;
llama_kv_slot_restorer kv_slot_restorer(kv_self);

const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = vocab.n_tokens();

uint32_t n_outputs = 0;
uint32_t n_outputs_prev = 0;

{
const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
if (ret != 0) {
return ret;
}
}

while (lctx.sbatch.n_tokens > 0) {
llama_ubatch ubatch;
{
const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
if (ret != 0) {
return ret;
}
}

const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;

GGML_ASSERT(n_threads > 0);

//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);

ggml_backend_sched_reset(lctx.sched.get());
@@ -8640,7 +8677,7 @@ static int llama_decode_impl(

// update the kv ring buffer
{
kv_self.head += n_tokens;
kv_self.head += ubatch.n_tokens;

// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
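The comment block above states the decode contract: 0 on success, a positive value as a warning (decoding could not proceed, for example because no free KV cache slot was found), a negative value on error, with the KV cache state restored (non-recurrent models) or cleared (recurrent models) on failure. A short caller-side sketch of acting on those codes, assuming the public llama_decode() wrapper forwards the value produced by llama_decode_impl():

#include <cstdio>

#include "llama.h"

// Returns true when the batch was decoded and logits/embeddings are available.
static bool decode_or_report(llama_context * ctx, llama_batch batch) {
    const int ret = llama_decode(ctx, batch);
    if (ret == 0) {
        return true;
    }
    if (ret > 0) {
        // warning: the KV cache was restored, so retrying with a smaller batch
        // or a larger context is a reasonable recovery strategy
        std::fprintf(stderr, "llama_decode warning %d: could not place batch in KV cache\n", ret);
    } else {
        // error: invalid input (e.g. empty batch, out-of-range token) or compute failure
        std::fprintf(stderr, "llama_decode error %d\n", ret);
    }
    return false;
}
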
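The kv_self.n heuristic in llama_prepare_ubatch rounds the highest used KV cell count up to a padding multiple so that attention only covers the populated part of the cache. A worked example of that arithmetic, assuming a padding of 32 purely for illustration (the real value comes from llama_kv_cache_get_padding(cparams)) and using GGML_PAD as defined in ggml.h:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// same rounding macro as ggml.h
#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main() {
    const uint32_t kv_size  = 4096; // total cells in the KV cache (kv_self.size)
    const uint32_t cell_max = 50;   // stand-in for llama_kv_cache_cell_max(kv_self)
    const uint32_t pad      = 32;   // assumed padding, for illustration only

    // same expression as in llama_prepare_ubatch
    const uint32_t kv_n = std::min(kv_size, std::max(pad, (uint32_t) GGML_PAD(cell_max, pad)));

    std::printf("GGML_PAD(%u, %u) = %u -> kv_self.n = %u\n",
                cell_max, pad, (uint32_t) GGML_PAD(cell_max, pad), kv_n);
    // prints: GGML_PAD(50, 32) = 64 -> kv_self.n = 64
    return 0;
}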