Export and run LLMs in C++ #1197

Open · wants to merge 4 commits into `main`
1 change: 1 addition & 0 deletions llms/export/.gitignore
@@ -0,0 +1 @@
build/
42 changes: 42 additions & 0 deletions llms/export/CMakeLists.txt
@@ -0,0 +1,42 @@
cmake_minimum_required(VERSION 3.27)

project(mlxlm LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
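# Ask the installed mlx Python package where its CMake config lives so that
# find_package(MLX) below can locate it.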
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

add_library(mlxlm)
target_link_libraries(mlxlm PUBLIC mlx)

include(FetchContent)

FetchContent_Declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
mlxlm PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)

target_sources(mlxlm
PRIVATE
mlxlm.cpp
tokenizer.cpp
unicode.cpp
unicode_data.cpp)

add_executable(main main.cpp)
target_link_libraries(main PRIVATE mlxlm)

add_executable(test test.cpp)
target_link_libraries(test PRIVATE mlxlm)
34 changes: 34 additions & 0 deletions llms/export/README.md
@@ -0,0 +1,34 @@
# Export LLMs to C++

Export language model inference from Python to run directly in C++.

To run, first install the requirements:

```bash
pip install -U mlx-lm
```

Then generate text from Python with:

```bash
python export.py generate "How tall is K2?"
```

To export the generation function, run:

```bash
python export.py export
```

Then build the C++ code (requires CMake):

```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```

And run the generation from C++ with:

```bash
./build/main llama3.1-instruct-4bit "How tall is K2?"
```
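
Because `export.py` dispatches through `fire`, the model and export path can likely also be overridden from the command line. The values below are just the defaults from `export.py`, shown for illustration:

```bash
# Hypothetical override of the export() defaults, relying on fire's flag mapping.
python export.py export \
    --model mlx-community/Meta-Llama-3.1-8B-Instruct-4bit \
    --path llama3.1-instruct-4bit

# Point the C++ binary at the exported directory.
./build/main llama3.1-instruct-4bit "How tall is K2?"
```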
171 changes: 171 additions & 0 deletions llms/export/export.py
@@ -0,0 +1,171 @@
import time
from pathlib import Path

import fire
import mlx.core as mx
from mlx_lm import load


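# KV-cache wrapper for export: new keys/values are written into a preallocated
# buffer at `offset` via slice_update, so the step function can be traced with
# a fixed cache structure.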
class ExportableCache:

def __init__(self, keys=None, values=None, offset=0):
self.offset = offset
self.keys = keys
self.values = values

def update_and_fetch(self, keys, values):
if self.keys is not None:
self.keys = mx.slice_update(self.keys, keys, self.offset, axes=(2,))
self.values = mx.slice_update(self.values, values, self.offset, axes=(2,))
else:
self.keys = keys
self.values = values
return self.keys, self.values

@property
def state(self):
return self.keys, self.values


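# Grow each cache buffer (and the boolean attention mask) to the next
# `cache_step_size` boundary above the current size; entries beyond the valid
# region stay masked out (False).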
def expand(cache, mask=None, cache_step_size=256):
cache_size = cache[0].shape[-2]
new_size = cache_step_size * ((cache_size + cache_step_size) // cache_step_size)

def expand_kv(x):
B, n_heads, _, head_dim = x.shape
new_x = mx.zeros((B, n_heads, new_size, head_dim), x.dtype)
new_x[..., : x.shape[2], :] = x
return new_x

cache = [expand_kv(c) for c in cache]
if mask is None:
mask = mx.full(new_size, False)
mask[:cache_size] = True
else:
mask = mx.concatenate([mask, mx.full(cache_step_size, False)])
return cache, mask


def causal_mask(N):
idx = mx.arange(N)
return idx[:, None] >= idx


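# One forward pass of the model. `state` is either just the attention mask
# (prompt pass) or the flattened per-layer keys/values followed by the cache
# offset and the mask (decode pass).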
def step(model, y, *state):
mask = state[-1]
if len(state) > 1:
cache, offset = state[:-2], state[-2]
cache = [
ExportableCache(keys, values, offset)
for keys, values in zip(cache[::2], cache[1::2])
]
else:
cache = [ExportableCache() for i in range(len(model.model.layers))]
logits = model(y, cache=cache, mask=mask)
cache = [y for x in cache for y in x.state]
return logits, *cache


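# Greedy decoding driver used by `generate`; it exercises the same two call
# signatures that `export` traces (one prompt pass, then single-token passes).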
def generate_step(prompt, model, max_tokens):
mx.eval(model)

compiled_step = mx.compile(lambda *args: step(model, *args), shapeless=True)

def _step(*args):
logits, *cache = compiled_step(*args)
return mx.argmax(logits[:, -1], axis=-1), *cache

y, *cache = _step(prompt, causal_mask(prompt.size))
mx.async_eval(y)
offset = mx.array(prompt.size, mx.uint32)
cache, mask = expand(cache)
n = 0
while True:
if n < max_tokens - 1:
if mask.size <= (prompt.size + n):
cache, mask = expand(cache, mask)
mask[prompt.size + n] = True
next_y, *cache = _step(y[None], *cache, offset, mask)
mx.async_eval(next_y)
offset += 1
n += 1
yield y.item()
if n == max_tokens:
break
y = next_y


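# Load the model, save the tokenizer files, and trace the step function twice
# (prompt pass and single-token decode pass) into `path`/model.mlxfn.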
def export(
model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
path="llama3.1-instruct-4bit",
):
model, tokenizer = load(model)

mx.eval(model)

tokenizer.save_pretrained(path)

_step = lambda *args: step(model, *args)

# Make example inputs
y_prompt = mx.array([[0, 0]], mx.uint32)
y_gen = mx.array([[0]], mx.uint32)
offset = mx.array([0], mx.uint32)

mask = causal_mask(y_prompt.size)
_, *cache = _step(y_prompt, mask)

model_path = str(Path(path) / "model.mlxfn")
with mx.exporter(model_path, _step, shapeless=True) as exporter:
exporter(y_prompt, mask)
cache, mask = expand(cache)
exporter(y_gen, *cache, offset, mask)


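# Pure-Python generation path (no export); useful for sanity-checking the C++
# output and for measuring prompt/generation throughput.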
def generate(
prompt,
model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
max_tokens=128,
):
print("[INFO] Loading model from disk.")
model, tokenizer = load(model)
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
return_tensors="mlx",
)

print("[INFO] Starting generation...")
tic = time.time()
tokens = []

detokenizer = tokenizer.detokenizer
detokenizer.reset()

for n, token in enumerate(generate_step(prompt, model, max_tokens)):
if n == 0:
prompt_tps = prompt.size / (time.time() - tic)
tic = time.time()

if token in tokenizer.eos_token_ids:
break
detokenizer.add_token(token)
print(detokenizer.last_segment, end="", flush=True)

detokenizer.finalize()
print(detokenizer.last_segment, flush=True)
gen_tps = (n + 1) / (time.time() - tic)
peak_memory = mx.metal.get_peak_memory() / 1e9
print("=" * 10)
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
print(f"Peak RAM: {peak_memory:.3f} GB")


if __name__ == "__main__":
fire.Fire(
{
"generate": generate,
"export": export,
}
)
18 changes: 18 additions & 0 deletions llms/export/main.cpp
@@ -0,0 +1,18 @@
// Copyright © 2024 Apple Inc.

#include <iostream>

#include "mlxlm.h"

int main(int argc, char *argv[]) {
if (argc < 3) {
std::cerr << "Must provide the model path and prompt." << std::endl;
return 1;
}
auto path = std::string(argv[1]);
auto prompt = std::string(argv[2]);

auto model = load_model(path + "/model.mlxfn");
auto tokenizer = load_tokenizer(path);
generate(model, tokenizer, prompt);
}
119 changes: 119 additions & 0 deletions llms/export/mlxlm.cpp
@@ -0,0 +1,119 @@
// Copyright © 2024 Apple Inc.

#include <chrono>
#include <iomanip>
#include <iostream>

#include "mlxlm.h"

namespace mx = mlx::core;

#define seconds(x) \
(std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e9)
#define time_now() std::chrono::high_resolution_clock::now()

// Maybe compile
[Inline review comment, maintainer]: You are already doing that :-)

std::function<mx::Args(mx::Args)> load_model(const std::string& path) {
return mx::compile(mx::import_function(path), /* shapeless = */ true);
}

// Maybe make tokenizer virtual
BPETokenizer load_tokenizer(const std::string& path) {
return BPETokenizer(path);
}

void generate(
const std::function<mx::Args(mx::Args)>& model,
const BPETokenizer& tokenizer,
const std::string& prompt,
int max_tokens /* = 256 */) {

auto prompt_tokens = tokenizer.encode(prompt);
int prompt_size = prompt_tokens.size();
auto y = mx::array(prompt_tokens.data(), {1, prompt_size}, mx::uint32);

auto create_causal_mask = [](int N) {
auto indices = mx::arange(N);
return mx::expand_dims(indices, 1) >= indices;
};

// Helper to expand the cache and mask
auto expand = [](auto& args, auto& mask) {
constexpr int cache_step_size = 256;
int cache_size = args[1].shape(-2);
int new_size = cache_step_size * ((cache_size + cache_step_size) / cache_step_size);
for (auto it = args.begin() + 1; it != args.end(); ++it) {
auto& x = *it;
auto shape = x.shape();
shape[2] = new_size;
auto new_x = mx::zeros(shape, x.dtype());
shape[2] = cache_size;
*it = mx::slice_update(new_x, x, mx::Shape(x.ndim(), 0), std::move(shape));
}
mask = mx::slice_update(mx::full({new_size}, false), mask, {0}, {cache_size});
};

auto tic = time_now();
float prompt_time;
int n = 0;

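  // Prompt pass: run the full prompt with a causal mask and greedily take the
  // next token from the last position's logits.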
mx::Args args;
{
args = model({y, create_causal_mask(y.size())});
auto logits = args[0];
logits = slice(logits, {0, -1, 0}, logits.shape());
y = argmax(logits, -1);
async_eval(y);
}

auto offset = mx::array(prompt_size, mx::uint32);
std::vector<int> tokens;

auto mask = mx::full({prompt_size}, true);
expand(args, mask);

for (; n < max_tokens; ++n) {
// Start next token decoding if needed
if (n < max_tokens - 1) {
args[0] = y;
auto m = prompt_size + n;
if (mask.size() <= m) {
expand(args, mask);
}
mask = mx::slice_update(mask, mx::array(true), {m}, {m + 1});
args.push_back(offset);
args.push_back(mask);
args = model(args);
args[0] = argmax(args[0], -1);
offset = offset + 1u;
async_eval(args[0]);
}

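    // Emit the token produced by the previous step while the next one runs
    // asynchronously; print decoded text whenever the tokenizer reports a
    // complete segment.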
auto token = y.item<int>();
if (token == tokenizer.eos_token_id()) {
break;
}
tokens.push_back(token);
auto [result, complete] = tokenizer.try_decode(tokens);
if (complete) {
std::cout << result << std::flush;
tokens.clear();
}
if (n == 0) {
prompt_time = seconds(time_now() - tic);
tic = time_now();
}

if (n < max_tokens - 1) {
y = args[0];
}
}
auto result = tokenizer.decode(tokens);
std::cout << result << std::flush;

auto gen_time = seconds(time_now() - tic);
std::cout << std::endl;
std::cout << std::setprecision(5) << "Prompt toks/sec "
<< prompt_size / prompt_time << "\nGeneration toks/sec "
<< (n + 1) / gen_time << std::endl;
}