feat: sync llama.cpp #100

Merged 3 commits on Dec 23, 2024
10 changes: 6 additions & 4 deletions android/src/main/CMakeLists.txt
@@ -2,22 +2,22 @@ cmake_minimum_required(VERSION 3.10)

project(llama.rn)

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 17)
set(RNLLAMA_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../cpp)

include_directories(${RNLLAMA_LIB_DIR})

set(
SOURCE_FILES
${RNLLAMA_LIB_DIR}/ggml.c
${RNLLAMA_LIB_DIR}/ggml-aarch64.c
${RNLLAMA_LIB_DIR}/ggml-alloc.c
${RNLLAMA_LIB_DIR}/ggml-backend.cpp
${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp
${RNLLAMA_LIB_DIR}/ggml-cpu.c
${RNLLAMA_LIB_DIR}/ggml-cpu.cpp
${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.c
${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.cpp
${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c
${RNLLAMA_LIB_DIR}/ggml-cpu-traits.cpp
${RNLLAMA_LIB_DIR}/ggml-opt.cpp
${RNLLAMA_LIB_DIR}/ggml-threading.cpp
${RNLLAMA_LIB_DIR}/ggml-quants.c
@@ -32,6 +32,8 @@ set(
${RNLLAMA_LIB_DIR}/sgemm.cpp
${RNLLAMA_LIB_DIR}/common.cpp
${RNLLAMA_LIB_DIR}/rn-llama.hpp
${RNLLAMA_LIB_DIR}/amx/amx.cpp
${RNLLAMA_LIB_DIR}/amx/mmq.cpp
${CMAKE_SOURCE_DIR}/jni-utils.h
${CMAKE_SOURCE_DIR}/jni.cpp
)
@@ -47,7 +49,7 @@ function(build_library target_name cpu_flags)

target_link_libraries(${target_name} ${LOG_LIB} android)

target_compile_options(${target_name} PRIVATE -pthread ${cpu_flags})
target_compile_options(${target_name} PRIVATE -DLM_GGML_USE_CPU -pthread ${cpu_flags})

if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
target_compile_options(${target_name} PRIVATE -DRNLLAMA_ANDROID_ENABLE_LOGGING)
3 changes: 0 additions & 3 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -217,8 +217,6 @@ public WritableMap completion(ReadableMap params) {
params.hasKey("mirostat_tau") ? (float) params.getDouble("mirostat_tau") : 5.00f,
// float mirostat_eta,
params.hasKey("mirostat_eta") ? (float) params.getDouble("mirostat_eta") : 0.10f,
// boolean penalize_nl,
params.hasKey("penalize_nl") ? params.getBoolean("penalize_nl") : false,
// int top_k,
params.hasKey("top_k") ? params.getInt("top_k") : 40,
// float top_p,
@@ -463,7 +461,6 @@ protected static native WritableMap doCompletion(
float mirostat,
float mirostat_tau,
float mirostat_eta,
boolean penalize_nl,
int top_k,
float top_p,
float min_p,
14 changes: 6 additions & 8 deletions android/src/main/jni.cpp
@@ -280,8 +280,8 @@ Java_com_rnllama_LlamaContext_initContext(

const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr);
const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr);
defaultParams.cache_type_k = cache_type_k_chars;
defaultParams.cache_type_v = cache_type_v_chars;
defaultParams.cache_type_k = rnllama::kv_cache_type_from_str(cache_type_k_chars);
defaultParams.cache_type_v = rnllama::kv_cache_type_from_str(cache_type_v_chars);

defaultParams.use_mlock = use_mlock;
defaultParams.use_mmap = use_mmap;
@@ -553,7 +553,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
jfloat mirostat,
jfloat mirostat_tau,
jfloat mirostat_eta,
jboolean penalize_nl,
jint top_k,
jfloat top_p,
jfloat min_p,
@@ -579,17 +578,17 @@ Java_com_rnllama_LlamaContext_doCompletion(
//llama_reset_timings(llama->ctx);

llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
llama->params.sparams.seed = (seed == -1) ? time(NULL) : seed;
llama->params.sampling.seed = (seed == -1) ? time(NULL) : seed;

int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;

llama->params.n_predict = n_predict;
llama->params.sparams.ignore_eos = ignore_eos;
llama->params.sampling.ignore_eos = ignore_eos;

auto & sparams = llama->params.sparams;
auto & sparams = llama->params.sampling;
sparams.temp = temperature;
sparams.penalty_last_n = penalty_last_n;
sparams.penalty_repeat = penalty_repeat;
@@ -598,7 +597,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
sparams.mirostat = mirostat;
sparams.mirostat_tau = mirostat_tau;
sparams.mirostat_eta = mirostat_eta;
sparams.penalize_nl = penalize_nl;
sparams.top_k = top_k;
sparams.top_p = top_p;
sparams.min_p = min_p;
@@ -714,7 +712,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
auto tokenResult = createWriteableMap(env);
putString(env, tokenResult, "token", to_send.c_str());

if (llama->params.sparams.n_probs > 0) {
if (llama->params.sampling.n_probs > 0) {
const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
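Note on the jni.cpp changes: this hunk tracks the upstream llama.cpp rename of `params.sparams` to `params.sampling`, drops the removed `penalize_nl` option, and stops assigning raw strings to `cache_type_k`/`cache_type_v`, converting them through `rnllama::kv_cache_type_from_str` instead. That helper lives in `rn-llama.hpp`, which is not part of this hunk; the following is only a minimal sketch of what such a conversion might look like, and the accepted names are an assumption rather than copied from the source.

```cpp
#include <stdexcept>
#include <string>
#include "ggml.h"

// Hypothetical sketch: map a user-facing cache-type string to the lm_ggml_type
// enum now expected by cache_type_k / cache_type_v. The set of names handled
// here is assumed, not taken from rn-llama.hpp.
static lm_ggml_type kv_cache_type_from_str(const std::string & s) {
    if (s == "f32")  return LM_GGML_TYPE_F32;
    if (s == "f16")  return LM_GGML_TYPE_F16;
    if (s == "q8_0") return LM_GGML_TYPE_Q8_0;
    if (s == "q4_0") return LM_GGML_TYPE_Q4_0;
    throw std::runtime_error("unsupported KV cache type: " + s);
}
```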
220 changes: 220 additions & 0 deletions cpp/amx/amx.cpp
@@ -0,0 +1,220 @@
#include "amx.h"
#include "common.h"
#include "mmq.h"
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "ggml-cpu-traits.h"

#if defined(__gnu_linux__)
#include <sys/syscall.h>
#include <unistd.h>
#endif

#include <cstdlib>
#include <cstring>
#include <memory>

#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)

// AMX type_traits
namespace ggml::cpu::amx {
class tensor_traits : public ggml::cpu::tensor_traits {
bool work_size(int /* n_threads */, const struct lm_ggml_tensor * op, size_t & size) override {
size = lm_ggml_backend_amx_desired_wsize(op);
return true;
}

bool compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * op) override {
if (op->op == LM_GGML_OP_MUL_MAT) {
lm_ggml_backend_amx_mul_mat(params, op);
return true;
}
return false;
}
};

static ggml::cpu::tensor_traits * get_tensor_traits(lm_ggml_backend_buffer_t, struct lm_ggml_tensor *) {
static tensor_traits traits;
return &traits;
}
} // namespace ggml::cpu::amx

// AMX buffer interface
static void lm_ggml_backend_amx_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) {
free(buffer->context);
}

static void * lm_ggml_backend_amx_buffer_get_base(lm_ggml_backend_buffer_t buffer) {
return (void *) (buffer->context);
}

static void lm_ggml_backend_amx_buffer_init_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) {
tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor);

LM_GGML_UNUSED(buffer);
}

static void lm_ggml_backend_amx_buffer_memset_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor,
uint8_t value, size_t offset, size_t size) {
memset((char *) tensor->data + offset, value, size);

LM_GGML_UNUSED(buffer);
}

static void lm_ggml_backend_amx_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor,
const void * data, size_t offset, size_t size) {
if (qtype_has_amx_kernels(tensor->type)) {
LM_GGML_LOG_DEBUG("%s: amx repack tensor %s of type %s\n", __func__, tensor->name, lm_ggml_type_name(tensor->type));
lm_ggml_backend_amx_convert_weight(tensor, data, offset, size);
} else {
memcpy((char *) tensor->data + offset, data, size);
}

LM_GGML_UNUSED(buffer);
}

/*
// need to figure what we need to do with buffer->extra.
static void lm_ggml_backend_amx_buffer_get_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
LM_GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
memcpy(data, (const char *)tensor->data + offset, size);

LM_GGML_UNUSED(buffer);
}

static bool lm_ggml_backend_amx_buffer_cpy_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) {
if (lm_ggml_backend_buffer_is_host(src->buffer)) {
if (qtype_has_amx_kernels(src->type)) {
lm_ggml_backend_amx_convert_weight(dst, src->data, 0, lm_ggml_nbytes(dst));
} else {
memcpy(dst->data, src->data, lm_ggml_nbytes(src));
}
return true;
}
return false;

LM_GGML_UNUSED(buffer);
}
*/

static void lm_ggml_backend_amx_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) {
memset(buffer->context, value, buffer->size);
}

static lm_ggml_backend_buffer_i lm_ggml_backend_amx_buffer_interface = {
/* .free_buffer = */ lm_ggml_backend_amx_buffer_free_buffer,
/* .get_base = */ lm_ggml_backend_amx_buffer_get_base,
/* .init_tensor = */ lm_ggml_backend_amx_buffer_init_tensor,
/* .memset_tensor = */ lm_ggml_backend_amx_buffer_memset_tensor,
/* .set_tensor = */ lm_ggml_backend_amx_buffer_set_tensor,
/* .get_tensor = */ nullptr,
/* .cpy_tensor = */ nullptr,
/* .clear = */ lm_ggml_backend_amx_buffer_clear,
/* .reset = */ nullptr,
};

static const char * lm_ggml_backend_amx_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) {
return "AMX";

LM_GGML_UNUSED(buft);
}

static lm_ggml_backend_buffer_t lm_ggml_backend_amx_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) {
void * data = lm_ggml_aligned_malloc(size);
if (data == NULL) {
fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
return NULL;
}

return lm_ggml_backend_buffer_init(buft, lm_ggml_backend_amx_buffer_interface, data, size);
}

static size_t lm_ggml_backend_amx_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) {
return TENSOR_ALIGNMENT;

LM_GGML_UNUSED(buft);
}

namespace ggml::cpu::amx {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(lm_ggml_backend_dev_t, const struct lm_ggml_tensor * op) override {
// handle only 2d gemm for now
auto is_contiguous_2d = [](const struct lm_ggml_tensor * t) {
return lm_ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
};

if (op->op == LM_GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
is_contiguous_2d(op->src[1]) && // src1 must be contiguous
op->src[0]->buffer && op->src[0]->buffer->buft == lm_ggml_backend_amx_buffer_type() &&
op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
(qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == LM_GGML_TYPE_F16))) {
// src1 must be host buffer
if (op->src[1]->buffer && !lm_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
// src1 must be float32
if (op->src[1]->type == LM_GGML_TYPE_F32) {
return true;
}
}
return false;
}

ggml::cpu::tensor_traits * get_tensor_traits(const struct lm_ggml_tensor * op) override {
if (op->op == LM_GGML_OP_MUL_MAT && op->src[0]->buffer &&
op->src[0]->buffer->buft == lm_ggml_backend_amx_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
}

return nullptr;
}
};
} // namespace ggml::cpu::amx

static size_t lm_ggml_backend_amx_buffer_type_get_alloc_size(lm_ggml_backend_buffer_type_t buft, const lm_ggml_tensor * tensor) {
return lm_ggml_backend_amx_get_alloc_size(tensor);

LM_GGML_UNUSED(buft);
}

#define ARCH_GET_XCOMP_PERM 0x1022
#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILECFG 17
#define XFEATURE_XTILEDATA 18

static bool lm_ggml_amx_init() {
#if defined(__gnu_linux__)
if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
fprintf(stderr, "AMX is not ready to be used!\n");
return false;
}
return true;
#elif defined(_WIN32)
return true;
#endif
}

lm_ggml_backend_buffer_type_t lm_ggml_backend_amx_buffer_type() {
static struct lm_ggml_backend_buffer_type lm_ggml_backend_buffer_type_amx = {
/* .iface = */ {
/* .get_name = */ lm_ggml_backend_amx_buffer_type_get_name,
/* .alloc_buffer = */ lm_ggml_backend_amx_buffer_type_alloc_buffer,
/* .get_alignment = */ lm_ggml_backend_amx_buffer_type_get_alignment,
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
/* .get_alloc_size = */ lm_ggml_backend_amx_buffer_type_get_alloc_size,
/* .is_host = */ nullptr,
},
/* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0),
/* .context = */ new ggml::cpu::amx::extra_buffer_type(),
};

if (!lm_ggml_amx_init()) {
return nullptr;
}

return &lm_ggml_backend_buffer_type_amx;
}

#endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__)
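
Note on cpp/amx/amx.cpp: the whole file is compiled only when both `__AMX_INT8__` and `__AVX512VNNI__` are defined, and `lm_ggml_backend_amx_buffer_type()` returns `nullptr` when `lm_ggml_amx_init()` cannot obtain XTILEDATA permission from the kernel, so callers have to treat the buffer type as optional. Below is a minimal usage sketch, assuming the standard ggml-backend allocation call under this fork's `lm_` prefix; the calling code is not part of this diff.

```cpp
#include "amx/amx.h"
#include "ggml-backend.h"

#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
// Sketch only: allocate a weight buffer in the AMX buffer type when available.
// Tensors written into this buffer via set_tensor are repacked for the AMX
// tile kernels whenever qtype_has_amx_kernels() accepts their type.
static lm_ggml_backend_buffer_t alloc_amx_weights(size_t size) {
    lm_ggml_backend_buffer_type_t buft = lm_ggml_backend_amx_buffer_type();
    if (buft == nullptr) {
        return nullptr; // AMX unusable: init failed or permission denied
    }
    return lm_ggml_backend_buft_alloc_buffer(buft, size);
}
#endif
```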
8 changes: 8 additions & 0 deletions cpp/amx/amx.h
@@ -0,0 +1,8 @@
#include "ggml-backend.h"
#include "ggml-cpu-impl.h"

// GGML internal header

#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
lm_ggml_backend_buffer_type_t lm_ggml_backend_amx_buffer_type(void);
#endif