export and run llama in C++
awni committed Jan 9, 2025
1 parent b8f0cac commit 761b2c9
Showing 14 changed files with 8,628 additions and 4 deletions.
1 change: 1 addition & 0 deletions llms/export/.gitignore
@@ -0,0 +1 @@
build/
42 changes: 42 additions & 0 deletions llms/export/CMakeLists.txt
@@ -0,0 +1,42 @@
cmake_minimum_required(VERSION 3.27)

project(mlxlm LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

find_package(
  Python 3.9
  COMPONENTS Interpreter Development.Module
  REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

add_library(mlxlm)
target_link_libraries(mlxlm PUBLIC mlx)

include(FetchContent)

FetchContent_Declare(
  json
  URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
  mlxlm PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)

target_sources(mlxlm
  PRIVATE
  mlxlm.cpp
  tokenizer.cpp
  unicode.cpp
  unicode_data.cpp)

add_executable(main main.cpp)
target_link_libraries(main PRIVATE mlxlm)

add_executable(test test.cpp)
target_link_libraries(test PRIVATE mlxlm)
34 changes: 34 additions & 0 deletions llms/export/README.md
@@ -0,0 +1,34 @@
# Export LLMs to C++

Export language model inference from Python to run directly in C++.

To run, first install the requirements:

```bash
pip install -U mlx-lm
```

Then generate text from Python with:

```bash
python export.py generate "How tall is K2?"
```

To export the generation function, which writes `model.mlxfn` and the tokenizer files to `llama3.1-instruct-4bit/`, run:

```bash
python export.py export
```

Then build the C++ code (requires CMake 3.27 or newer; CMake locates MLX via `python -m mlx --cmake-dir`, so make sure `mlx` is installed in the active Python environment):

```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```

And run the generation from C++ with:

```bash
./build/main llama3.1-instruct-4bit "How tall is K2?"
```
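
As a quick sanity check, the exported graph can also be called back from Python. The sketch below assumes the export step above produced `llama3.1-instruct-4bit/model.mlxfn` and that your MLX version provides `mx.import_function`; the imported function takes the same `(tokens, mask)` inputs traced in `export.py` and returns the logits followed by the key/value cache arrays:

```python
import mlx.core as mx

# Load the exported compute graph (path assumed from the export step above).
step = mx.import_function("llama3.1-instruct-4bit/model.mlxfn")

# Two placeholder prompt tokens and a matching causal mask, mirroring export.py.
tokens = mx.array([[1, 2]], mx.uint32)
idx = mx.arange(tokens.size)
mask = idx[:, None] >= idx

logits, *cache = step(tokens, mask)
print(logits.shape)  # (1, prompt_length, vocab_size)
```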
171 changes: 171 additions & 0 deletions llms/export/export.py
@@ -0,0 +1,171 @@
import time
from pathlib import Path

import fire
import mlx.core as mx
from mlx_lm import load


class ExportableCache:

    def __init__(self, keys=None, values=None, offset=0):
        self.offset = offset
        self.keys = keys
        self.values = values

    def update_and_fetch(self, keys, values):
        if self.keys is not None:
            # Write the new keys/values into the preallocated buffers at the
            # current offset rather than concatenating.
            self.keys = mx.slice_update(self.keys, keys, self.offset, axes=(2,))
            self.values = mx.slice_update(self.values, values, self.offset, axes=(2,))
        else:
            self.keys = keys
            self.values = values
        return self.keys, self.values

    @property
    def state(self):
        return self.keys, self.values


def expand(cache, mask=None, cache_step_size=256):
    # Grow the cache (and mask) to the next multiple of cache_step_size so
    # reallocation only happens every cache_step_size tokens.
    cache_size = cache[0].shape[-2]
    new_size = cache_step_size * ((cache_size + cache_step_size) // cache_step_size)

    def expand_kv(x):
        B, n_heads, _, head_dim = x.shape
        new_x = mx.zeros((B, n_heads, new_size, head_dim), x.dtype)
        new_x[..., : x.shape[2], :] = x
        return new_x

    cache = [expand_kv(c) for c in cache]
    if mask is None:
        mask = mx.full(new_size, False)
        mask[:cache_size] = True
    else:
        mask = mx.concatenate([mask, mx.full(cache_step_size, False)])
    return cache, mask


def causal_mask(N):
    idx = mx.arange(N)
    return idx[:, None] >= idx


def step(model, y, *state):
    mask = state[-1]
    if len(state) > 1:
        # The cache arrives as a flat list k0, v0, k1, v1, ... followed by the
        # offset and the mask; rewrap it in one ExportableCache per layer.
        cache, offset = state[:-2], state[-2]
        cache = [
            ExportableCache(keys, values, offset)
            for keys, values in zip(cache[::2], cache[1::2])
        ]
    else:
        cache = [ExportableCache() for i in range(len(model.model.layers))]
    logits = model(y, cache=cache, mask=mask)
    cache = [y for x in cache for y in x.state]
    return logits, *cache


def generate_step(prompt, model, max_tokens):
    mx.eval(model)

    compiled_step = mx.compile(lambda *args: step(model, *args), shapeless=True)

    def _step(*args):
        logits, *cache = compiled_step(*args)
        return mx.argmax(logits[:, -1], axis=-1), *cache

    # Prefill on the full prompt, then decode one token at a time, expanding
    # the cache and mask whenever the next position would overflow them.
    y, *cache = _step(prompt, causal_mask(prompt.size))
    mx.async_eval(y)
    offset = mx.array(prompt.size, mx.uint32)
    cache, mask = expand(cache)
    n = 0
    while True:
        if n < max_tokens - 1:
            if mask.size <= (prompt.size + n):
                cache, mask = expand(cache, mask)
            mask[prompt.size + n] = True
            next_y, *cache = _step(y[None], *cache, offset, mask)
            mx.async_eval(next_y)
            offset += 1
        n += 1
        yield y.item()
        if n == max_tokens:
            break
        y = next_y


def export(
model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
path="llama3.1-instruct-4bit",
):
model, tokenizer = load(model)

mx.eval(model)

tokenizer.save_pretrained(path)

_step = lambda *args: step(model, *args)

# Make example inputs
y_prompt = mx.array([[0, 0]], mx.uint32)
y_gen = mx.array([[0]], mx.uint32)
offset = mx.array([0], mx.uint32)

mask = causal_mask(y_prompt.size)
_, *cache = _step(y_prompt, mask)

model_path = str(Path(path) / "model.mlxfn")
with mx.exporter(model_path, _step, shapeless=True) as exporter:
exporter(y_prompt, mask)
cache, mask = expand(cache)
exporter(y_gen, *cache, offset, mask)


def generate(
    prompt,
    model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
    max_tokens=128,
):
    print("[INFO] Loading model from disk.")
    model, tokenizer = load(model)
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        return_tensors="mlx",
    )

    print("[INFO] Starting generation...")
    tic = time.time()
    tokens = []

    detokenizer = tokenizer.detokenizer
    detokenizer.reset()

    for n, token in enumerate(generate_step(prompt, model, max_tokens)):
        if n == 0:
            prompt_tps = prompt.size / (time.time() - tic)
            tic = time.time()

        if token in tokenizer.eos_token_ids:
            break
        detokenizer.add_token(token)
        print(detokenizer.last_segment, end="", flush=True)

    detokenizer.finalize()
    print(detokenizer.last_segment, flush=True)
    gen_tps = (n + 1) / (time.time() - tic)
    peak_memory = mx.metal.get_peak_memory() / 1e9
    print("=" * 10)
    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
    print(f"Generation: {gen_tps:.3f} tokens-per-sec")
    print(f"Peak RAM: {peak_memory:.3f} GB")


if __name__ == "__main__":
    fire.Fire(
        {
            "generate": generate,
            "export": export,
        }
    )
18 changes: 18 additions & 0 deletions llms/export/main.cpp
@@ -0,0 +1,18 @@
// Copyright © 2024 Apple Inc.

#include <iostream>

#include "mlxlm.h"

int main(int argc, char *argv[]) {
  if (argc < 3) {
    std::cerr << "Must provide the model path and prompt." << std::endl;
    return 1;
  }
  auto path = std::string(argv[1]);
  auto prompt = std::string(argv[2]);

  auto model = load_model(path + "/model.mlxfn");
  auto tokenizer = load_tokenizer(path);
  generate(model, tokenizer, prompt);
}
119 changes: 119 additions & 0 deletions llms/export/mlxlm.cpp
@@ -0,0 +1,119 @@
// Copyright © 2024 Apple Inc.

#include <chrono>
#include <iomanip>
#include <iostream>

#include "mlxlm.h"

namespace mx = mlx::core;

#define seconds(x) \
  (std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e9)
#define time_now() std::chrono::high_resolution_clock::now()

// Maybe compile
std::function<mx::Args(mx::Args)> load_model(const std::string& path) {
  return mx::compile(mx::import_function(path), /* shapeless = */ true);
}

// Maybe make tokenizer virtual
BPETokenizer load_tokenizer(const std::string& path) {
  return BPETokenizer(path);
}

void generate(
    const std::function<mx::Args(mx::Args)>& model,
    const BPETokenizer& tokenizer,
    const std::string& prompt,
    int max_tokens /* = 256 */) {

  auto prompt_tokens = tokenizer.encode(prompt);
  int prompt_size = prompt_tokens.size();
  auto y = mx::array(prompt_tokens.data(), {1, prompt_size}, mx::uint32);

  auto create_causal_mask = [](int N) {
    auto indices = mx::arange(N);
    return mx::expand_dims(indices, 1) >= indices;
  };

  // Helper to expand the cache and mask
  auto expand = [](auto& args, auto& mask) {
    constexpr int cache_step_size = 256;
    int cache_size = args[1].shape(-2);
    int new_size = cache_step_size * ((cache_size + cache_step_size) / cache_step_size);
    for (auto it = args.begin() + 1; it != args.end(); ++it) {
      auto& x = *it;
      auto shape = x.shape();
      shape[2] = new_size;
      auto new_x = mx::zeros(shape, x.dtype());
      shape[2] = cache_size;
      *it = mx::slice_update(new_x, x, mx::Shape(x.ndim(), 0), std::move(shape));
    }
    mask = mx::slice_update(mx::full({new_size}, false), mask, {0}, {cache_size});
  };

  auto tic = time_now();
  float prompt_time;
  int n = 0;

  // Prefill: run the whole prompt with a causal mask and take the argmax of
  // the last position's logits as the first generated token.
  mx::Args args;
  {
    args = model({y, create_causal_mask(y.size())});
    auto logits = args[0];
    logits = slice(logits, {0, -1, 0}, logits.shape());
    y = argmax(logits, -1);
    async_eval(y);
  }

  auto offset = mx::array(prompt_size, mx::uint32);
  std::vector<int> tokens;

  auto mask = mx::full({prompt_size}, true);
  expand(args, mask);

  for (; n < max_tokens; ++n) {
    // Start next token decoding if needed
    if (n < max_tokens - 1) {
      args[0] = y;
      auto m = prompt_size + n;
      if (mask.size() <= m) {
        expand(args, mask);
      }
      mask = mx::slice_update(mask, mx::array(true), {m}, {m + 1});
      args.push_back(offset);
      args.push_back(mask);
      args = model(args);
      args[0] = argmax(args[0], -1);
      offset = offset + 1u;
      async_eval(args[0]);
    }

    // Stream out the current token while the next one is being computed.
    auto token = y.item<int>();
    if (token == tokenizer.eos_token_id()) {
      break;
    }
    tokens.push_back(token);
    auto [result, complete] = tokenizer.try_decode(tokens);
    if (complete) {
      std::cout << result << std::flush;
      tokens.clear();
    }
    if (n == 0) {
      prompt_time = seconds(time_now() - tic);
      tic = time_now();
    }

    if (n < max_tokens - 1) {
      y = args[0];
    }
  }
  auto result = tokenizer.decode(tokens);
  std::cout << result << std::flush;

  auto gen_time = seconds(time_now() - tic);
  std::cout << std::endl;
  std::cout << std::setprecision(5) << "Prompt toks/sec "
            << prompt_size / prompt_time << "\nGeneration toks/sec "
            << (n + 1) / gen_time << std::endl;
}