diff --git a/bert/convert.py b/bert/convert.py index 63b448b4c..8806ba0d8 100644 --- a/bert/convert.py +++ b/bert/convert.py @@ -1,6 +1,6 @@ import argparse -import numpy +import mlx.core as mx from transformers import AutoModel @@ -23,9 +23,9 @@ def convert(bert_model: str, mlx_model: str) -> None: model = AutoModel.from_pretrained(bert_model) # save the tensors tensors = { - replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items() + replace_key(key): mx.array(tensor) for key, tensor in model.state_dict().items() } - numpy.savez(mlx_model, **tensors) + mx.save_safetensors(mlx_model, tensors) if __name__ == "__main__": @@ -39,7 +39,7 @@ def convert(bert_model: str, mlx_model: str) -> None: parser.add_argument( "--mlx-model", type=str, - default="weights/bert-base-uncased.npz", + default="bert-base-uncased.safetensors", help="The output path for the MLX BERT weights.", ) args = parser.parse_args() diff --git a/bert/model.py b/bert/model.py index e7a19c99a..8fd234e8d 100644 --- a/bert/model.py +++ b/bert/model.py @@ -136,10 +136,7 @@ def load_model( def run(bert_model: str, mlx_model: str, batch: List[str]): model, tokenizer = load_model(bert_model, mlx_model) - - tokens = tokenizer(batch, return_tensors="np", padding=True) - tokens = {key: mx.array(v) for key, v in tokens.items()} - + tokens = tokenizer(batch, return_tensors="mlx", padding=True) return model(**tokens) @@ -149,13 +146,13 @@ def run(bert_model: str, mlx_model: str, batch: List[str]): "--bert-model", type=str, default="bert-base-uncased", - help="The huggingface name of the BERT model to save.", + help="The huggingface name of the BERT model.", ) parser.add_argument( "--mlx-model", type=str, - default="weights/bert-base-uncased.npz", - help="The path of the stored MLX BERT weights (npz file).", + default="bert-base-uncased.safetensors", + help="The path of the stored MLX BERT weights.", ) parser.add_argument( "--text", diff --git a/bert/requirements.txt b/bert/requirements.txt index a6b564c5e..9bb67eecd 100644 --- a/bert/requirements.txt +++ b/bert/requirements.txt @@ -1,3 +1,2 @@ mlx>=0.0.5 transformers -numpy diff --git a/bert/test.py b/bert/test.py index e5462ba3c..53e3c6b30 100644 --- a/bert/test.py +++ b/bert/test.py @@ -29,8 +29,8 @@ def run_torch(bert_model: str, batch: List[str]): parser.add_argument( "--mlx-model", type=str, - default="weights/bert-base-uncased.npz", - help="The path of the stored MLX BERT weights (npz file).", + default="bert-base-uncased.safetensors", + help="The path of the stored MLX BERT weights.", ) parser.add_argument( "--text", diff --git a/bert/weights/.gitignore b/bert/weights/.gitignore deleted file mode 100644 index 44662642a..000000000 --- a/bert/weights/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.npz \ No newline at end of file diff --git a/llms/export/.gitignore b/llms/export/.gitignore new file mode 100644 index 000000000..567609b12 --- /dev/null +++ b/llms/export/.gitignore @@ -0,0 +1 @@ +build/ diff --git a/llms/export/CMakeLists.txt b/llms/export/CMakeLists.txt new file mode 100644 index 000000000..e2c18fc9f --- /dev/null +++ b/llms/export/CMakeLists.txt @@ -0,0 +1,33 @@ +cmake_minimum_required(VERSION 3.27) + +project(mlxlm LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +find_package( + Python 3.9 + COMPONENTS Interpreter Development.Module + REQUIRED) +execute_process( + COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE MLX_ROOT) 
+find_package(MLX CONFIG REQUIRED) + +add_library(mlxlm) +target_link_libraries(mlxlm PUBLIC mlx) + +add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/third_party) + +target_sources(mlxlm + PRIVATE + mlxlm.cpp + tokenizer.cpp) + +add_executable(main main.cpp) +target_link_libraries(main PRIVATE mlxlm) + +add_executable(test test.cpp) +target_link_libraries(test PRIVATE mlxlm) diff --git a/llms/export/README.md b/llms/export/README.md new file mode 100644 index 000000000..3d6fac8c5 --- /dev/null +++ b/llms/export/README.md @@ -0,0 +1,34 @@ +# Export LLMs to C++ + +Export language model inference from Python to run directly in C++. + +To run, first install the requirements: + +```bash +pip install -U mlx-lm +``` + +Then generate text from Python with: + +```bash +python export.py generate "How tall is K2?" +``` + +To export the generation function run: + +```bash +python export.py export +``` + +Then build the C++ code (requires CMake): + +```bash +cmake -B build -DCMAKE_BUILD_TYPE=Release +cmake --build build +``` + +And run the generation from C++ with: + +```bash +./build/main lama3.1-instruct-4bit "How tall is K2?" +``` diff --git a/llms/export/export.py b/llms/export/export.py new file mode 100644 index 000000000..36a2fe3a5 --- /dev/null +++ b/llms/export/export.py @@ -0,0 +1,171 @@ +import time +from pathlib import Path + +import fire +import mlx.core as mx +from mlx_lm import load + + +class ExportableCache: + + def __init__(self, keys=None, values=None, offset=0): + self.offset = offset + self.keys = keys + self.values = values + + def update_and_fetch(self, keys, values): + if self.keys is not None: + self.keys = mx.slice_update(self.keys, keys, self.offset, axes=(2,)) + self.values = mx.slice_update(self.values, values, self.offset, axes=(2,)) + else: + self.keys = keys + self.values = values + return self.keys, self.values + + @property + def state(self): + return self.keys, self.values + + +def expand(cache, mask=None, cache_step_size=256): + cache_size = cache[0].shape[-2] + new_size = cache_step_size * ((cache_size + cache_step_size) // cache_step_size) + + def expand_kv(x): + B, n_heads, _, head_dim = x.shape + new_x = mx.zeros((B, n_heads, new_size, head_dim), x.dtype) + new_x[..., : x.shape[2], :] = x + return new_x + + cache = [expand_kv(c) for c in cache] + if mask is None: + mask = mx.full(new_size, False) + mask[:cache_size] = True + else: + mask = mx.concatenate([mask, mx.full(cache_step_size, False)]) + return cache, mask + + +def causal_mask(N): + idx = mx.arange(N) + return idx[:, None] >= idx + + +def step(model, y, *state): + mask = state[-1] + if len(state) > 1: + cache, offset = state[:-2], state[-2] + cache = [ + ExportableCache(keys, values, offset) + for keys, values in zip(cache[::2], cache[1::2]) + ] + else: + cache = [ExportableCache() for i in range(len(model.model.layers))] + logits = model(y, cache=cache, mask=mask) + cache = [y for x in cache for y in x.state] + return logits, *cache + + +def generate_step(prompt, model, max_tokens): + mx.eval(model) + + compiled_step = mx.compile(lambda *args: step(model, *args), shapeless=True) + + def _step(*args): + logits, *cache = compiled_step(*args) + return mx.argmax(logits[:, -1], axis=-1), *cache + + y, *cache = _step(prompt, causal_mask(prompt.size)) + mx.async_eval(y) + offset = mx.array(prompt.size, mx.uint32) + cache, mask = expand(cache) + n = 0 + while True: + if n < max_tokens - 1: + if mask.size <= (prompt.size + n): + cache, mask = expand(cache, mask) + mask[prompt.size + n] = True + next_y, *cache = 
_step(y[None], *cache, offset, mask) + mx.async_eval(next_y) + offset += 1 + n += 1 + yield y.item() + if n == max_tokens: + break + y = next_y + + +def export( + model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", + path="llama3.1-instruct-4bit", +): + model, tokenizer = load(model) + + mx.eval(model) + + tokenizer.save_pretrained(path) + + _step = lambda *args: step(model, *args) + + # Make example inputs + y_prompt = mx.array([[0, 0]], mx.uint32) + y_gen = mx.array([[0]], mx.uint32) + offset = mx.array([0], mx.uint32) + + mask = causal_mask(y_prompt.size) + _, *cache = _step(y_prompt, mask) + + model_path = str(Path(path) / "model.mlxfn") + with mx.exporter(model_path, _step, shapeless=True) as exporter: + exporter(y_prompt, mask) + cache, mask = expand(cache) + exporter(y_gen, *cache, offset, mask) + + +def generate( + prompt, + model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", + max_tokens=128, +): + print("[INFO] Loading model from disk.") + model, tokenizer = load(model) + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + return_tensors="mlx", + ) + + print("[INFO] Starting generation...") + tic = time.time() + tokens = [] + + detokenizer = tokenizer.detokenizer + detokenizer.reset() + + for n, token in enumerate(generate_step(prompt, model, max_tokens)): + if n == 0: + prompt_tps = prompt.size / (time.time() - tic) + tic = time.time() + + if token in tokenizer.eos_token_ids: + break + detokenizer.add_token(token) + print(detokenizer.last_segment, end="", flush=True) + + detokenizer.finalize() + print(detokenizer.last_segment, flush=True) + gen_tps = (n + 1) / (time.time() - tic) + peak_memory = mx.metal.get_peak_memory() / 1e9 + print("=" * 10) + print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") + print(f"Generation: {gen_tps:.3f} tokens-per-sec") + print(f"Peak RAM: {peak_memory:.3f} GB") + + +if __name__ == "__main__": + fire.Fire( + { + "generate": generate, + "export": export, + } + ) diff --git a/llms/export/main.cpp b/llms/export/main.cpp new file mode 100644 index 000000000..262a3014d --- /dev/null +++ b/llms/export/main.cpp @@ -0,0 +1,18 @@ +// Copyright © 2024 Apple Inc. + +#include + +#include "mlxlm.h" + +int main(int argc, char *argv[]) { + if (argc < 3) { + std::cerr << "Must provide the model path and prompt." << std::endl; + return 1; + } + auto path = std::string(argv[1]); + auto prompt = std::string(argv[2]); + + auto model = load_model(path + "/model.mlxfn"); + auto tokenizer = load_tokenizer(path); + generate(model, tokenizer, prompt); +} diff --git a/llms/export/mlxlm.cpp b/llms/export/mlxlm.cpp new file mode 100644 index 000000000..b5aa9a246 --- /dev/null +++ b/llms/export/mlxlm.cpp @@ -0,0 +1,120 @@ +// Copyright © 2024 Apple Inc. 
+ +#include +#include +#include + +#include "mlxlm.h" + +namespace mx = mlx::core; + +#define seconds(x) \ + (std::chrono::duration_cast(x).count() / 1e9) +#define time_now() std::chrono::high_resolution_clock::now() + +// Maybe compile +std::function load_model(const std::string &path) { + return mx::compile(mx::import_function(path), /* shapeless = */ true); +} + +// Maybe make tokenizer virtual +BPETokenizer load_tokenizer(const std::string &path) { + return BPETokenizer(path); +} + +void generate(const std::function &model, + const BPETokenizer &tokenizer, const std::string &prompt, + int max_tokens /* = 256 */) { + + auto prompt_tokens = tokenizer.encode(prompt); + int prompt_size = prompt_tokens.size(); + auto y = mx::array(prompt_tokens.data(), {1, prompt_size}, mx::uint32); + + auto create_causal_mask = [](int N) { + auto indices = mx::arange(N); + return mx::expand_dims(indices, 1) >= indices; + }; + + // Helper to expand the cache and mask + auto expand = [](auto &args, auto &mask) { + constexpr int cache_step_size = 256; + int cache_size = args[1].shape(-2); + int new_size = + cache_step_size * ((cache_size + cache_step_size) / cache_step_size); + for (auto it = args.begin() + 1; it != args.end(); ++it) { + auto &x = *it; + auto shape = x.shape(); + shape[2] = new_size; + auto new_x = mx::zeros(shape, x.dtype()); + shape[2] = cache_size; + *it = + mx::slice_update(new_x, x, mx::Shape(x.ndim(), 0), std::move(shape)); + } + mask = + mx::slice_update(mx::full({new_size}, false), mask, {0}, {cache_size}); + }; + + auto tic = time_now(); + float prompt_time; + int n = 0; + + mx::Args args; + { + args = model({y, create_causal_mask(y.size())}); + auto logits = args[0]; + logits = slice(logits, {0, -1, 0}, logits.shape()); + y = argmax(logits, -1); + async_eval(y); + } + + auto offset = mx::array(prompt_size, mx::uint32); + std::vector tokens; + + auto mask = mx::full({prompt_size}, true); + expand(args, mask); + + for (; n < max_tokens; ++n) { + // Start next token decoding if needed + if (n < max_tokens - 1) { + args[0] = y; + auto m = prompt_size + n; + if (mask.size() <= m) { + expand(args, mask); + } + mask = mx::slice_update(mask, mx::array(true), {m}, {m + 1}); + args.push_back(offset); + args.push_back(mask); + args = model(args); + args[0] = argmax(args[0], -1); + offset = offset + 1u; + async_eval(args[0]); + } + + auto token = y.item(); + if (token == tokenizer.eos_token_id()) { + break; + } + tokens.push_back(token); + auto [result, complete] = tokenizer.try_decode(tokens); + if (complete) { + std::cout << result << std::flush; + tokens.clear(); + } + if (n == 0) { + prompt_time = seconds(time_now() - tic); + tic = time_now(); + } + + if (n < max_tokens - 1) { + y = args[0]; + } + } + auto result = tokenizer.decode(tokens); + std::cout << result << std::flush; + + auto gen_time = seconds(time_now() - tic); + std::cout << std::endl; + std::cout << std::setprecision(5) << "Prompt toks/sec " + << prompt_size / prompt_time << "\nGeneration toks/sec " + << (n + 1) / gen_time << std::endl; +} diff --git a/llms/export/mlxlm.h b/llms/export/mlxlm.h new file mode 100644 index 000000000..9dba3b0a1 --- /dev/null +++ b/llms/export/mlxlm.h @@ -0,0 +1,15 @@ +// Copyright © 2024 Apple Inc. 
+
+#include <mlx/mlx.h>
+
+#include "tokenizer.h"
+
+namespace mx = mlx::core;
+
+std::function<mx::Args(mx::Args)> load_model(const std::string &path);
+
+BPETokenizer load_tokenizer(const std::string &path);
+
+void generate(const std::function<mx::Args(mx::Args)> &model,
+              const BPETokenizer &tokenizer, const std::string &prompt,
+              int max_tokens = 256);
diff --git a/llms/export/test.cpp b/llms/export/test.cpp
new file mode 100644
index 000000000..0c44ab85b
--- /dev/null
+++ b/llms/export/test.cpp
@@ -0,0 +1,23 @@
+// Copyright © 2024 Apple Inc.
+
+#include "tokenizer.h"
+#include <iostream>
+
+template <typename T, typename U = T> void check(const T &x, const U &y) {
+  if (x != y) {
+    std::cerr << "Mismatch" << std::endl;
+  }
+}
+
+void test_tokenizer(const std::string &path) {
+  BPETokenizer tokenizer(path);
+  check(tokenizer.encode("hello world!"), {128000, 15339, 1917, 0});
+  check(tokenizer.decode({15339}), "hello");
+  check(tokenizer.decode({0}), "!");
+  check(tokenizer.decode({1917}), " world");
+  check(tokenizer.encode("we'd see you say 世界你好真实好的很啊"),
+        {128000, 906, 4265, 220, 1518, 256, 499, 2019, 127365, 57668, 53901,
+         89151, 41073, 110085, 101600, 102856});
+}
+
+int main(int argc, char *argv[]) { test_tokenizer("."); }
diff --git a/llms/export/third_party/CMakeLists.txt b/llms/export/third_party/CMakeLists.txt
new file mode 100644
index 000000000..0496ea57e
--- /dev/null
+++ b/llms/export/third_party/CMakeLists.txt
@@ -0,0 +1,20 @@
+include(FetchContent)
+
+FetchContent_Declare(
+  json
+  URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
+FetchContent_MakeAvailable(json)
+target_include_directories(
+  mlxlm PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
+
+execute_process(
+  COMMAND zsh "${CMAKE_CURRENT_SOURCE_DIR}/download_unicode.sh" "${CMAKE_CURRENT_BINARY_DIR}"
+  COMMAND_ERROR_IS_FATAL ANY
+)
+
+target_sources(mlxlm
+  PRIVATE
+  ${CMAKE_CURRENT_BINARY_DIR}/unicode.cpp
+  ${CMAKE_CURRENT_BINARY_DIR}/unicode-data.cpp)
+
+target_include_directories(mlxlm PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/..)
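Note: the exported `model.mlxfn` that `load_model` wraps on the C++ side can also be sanity-checked from Python with `mx.import_function`. The sketch below is illustrative rather than part of the change; it assumes `python export.py export` was run with its default `llama3.1-instruct-4bit` output directory, and the two call signatures simply mirror `step()` in `export.py`: `(tokens, mask)` for the prompt and `(token, *kv_cache, offset, mask)` for each decode step.

```python
# Illustrative only: drive the exported graph from Python instead of C++.
import mlx.core as mx

fn = mx.import_function("llama3.1-instruct-4bit/model.mlxfn")

# Prompt signature traced by export(): (tokens, causal_mask).
tokens = mx.array([[128000, 15339, 1917, 0]], mx.uint32)  # "hello world!" ids from test.cpp
idx = mx.arange(tokens.size)
mask = idx[:, None] >= idx

logits, *cache = fn(tokens, mask)
y = mx.argmax(logits[:, -1], axis=-1)

# Decode signature, after growing `cache` and `mask` as expand() does in export.py:
#   logits, *cache = fn(y[None], *cache, offset, mask)
```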
diff --git a/llms/export/third_party/download_unicode.sh b/llms/export/third_party/download_unicode.sh new file mode 100644 index 000000000..6e52fe387 --- /dev/null +++ b/llms/export/third_party/download_unicode.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +commit=1204f9727005974587d6fc1dcd4d4f0ead87c856 +url=https://raw.githubusercontent.com/ggerganov/llama.cpp/${commit}/src/ + +for file in 'unicode.cpp' 'unicode.h' 'unicode-data.cpp' 'unicode-data.h' +do + curl -OL ${url}/${file} --output-dir $1 2>/dev/null +done diff --git a/llms/export/tokenizer.cpp b/llms/export/tokenizer.cpp new file mode 100644 index 000000000..5e68a43bd --- /dev/null +++ b/llms/export/tokenizer.cpp @@ -0,0 +1,201 @@ + +#include +#include +#include +#include +#include + +#include "third_party/unicode.h" +#include "tokenizer.h" + +using json = nlohmann::json; + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +std::pair utf8_to_utf16(const std::string &s) { + static std::string replace_str = std::string(1, 0xFF); + static std::wstring replace_wstr = std::wstring(1, 0xFFFD); + std::wstring_convert> cvt(replace_str, + replace_wstr); + auto out = cvt.from_bytes(s); + return {out, cvt.converted()}; +} +#pragma GCC diagnostic pop + +auto make_byte_decoder() { + std::unordered_map byte_decoder; + std::vector limits = {0, '!', '~' + 1, L'¡', + L'¬' + 1, L'®', L'ÿ' + 1}; + char n = 0; + for (int i = 0; i < limits.size() - 1; ++i) { + auto start = limits[i]; + auto stop = limits[i + 1]; + if (i % 2 == 0) { + for (int b = start; b < stop; ++b) { + byte_decoder[256 + n++] = b; + } + } else { + for (int b = start; b < stop; ++b) { + byte_decoder[b] = b; + } + } + } + return byte_decoder; +} + +auto BPETokenizer::byte_decoder_ = make_byte_decoder(); + +BPETokenizer::BPETokenizer(const std::string &path_) { + auto path = std::filesystem::path(path_); + std::ifstream ifs(path / "tokenizer.json"); + auto tokenizer = json::parse(ifs); + auto model = tokenizer["model"]; + token_to_id_ = model["vocab"]; + id_to_token_.resize(token_to_id_.size()); + for (auto &[s, id] : token_to_id_) { + if (id >= id_to_token_.size()) { + id_to_token_.resize(id + 1); + } + id_to_token_[id] = s; + } + std::string type = model["type"]; + auto merges = model["merges"]; + for (auto &s : merges) { + if (s.is_string()) { + merges_.emplace(s, merges_.size()); + } else { + std::string s1 = s[0]; + std::string s2 = s[1]; + merges_.emplace(s1 + " " + s2, merges_.size()); + } + } + + auto added_tokens = tokenizer["added_tokens"]; + for (auto &added_token : added_tokens) { + int id = added_token["id"]; + if (id >= id_to_token_.size()) { + id_to_token_.resize(id + 1); + } + id_to_token_[id] = added_token["content"]; + if (id_to_token_[id] == "<|begin_of_text|>") { + bos_id_ = id; + } else if (id_to_token_[id] == "<|eot_id|>") { + eos_id_ = id; + } + } + + // Currently hardcoded to Llama3 BPE regex + pre_tokenizer_regex_ = { + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}" + "\\p{N}]?\\p{L}+|\\p{N}{1,3}| " + "?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"}; +} + +std::vector BPETokenizer::encode(std::string text) const { + + auto segments = unicode_regex_split(text, pre_tokenizer_regex_); + + auto one_step_merge = [this](std::string segment, std::vector &splits) { + int merge_idx; + int rank = INT32_MAX; + std::string candidate; + for (int i = 0; i < splits.size() - 2; ++i) { + auto start = splits[i]; + auto mid = splits[i + 1]; + auto end = splits[i + 2]; + candidate.clear(); + 
candidate.insert(candidate.end(), segment.begin() + start, + segment.begin() + mid); + candidate += " "; + candidate.insert(candidate.end(), segment.begin() + mid, + segment.begin() + end); + if (auto it = merges_.find(candidate); it != merges_.end()) { + if (it->second < rank) { + merge_idx = i; + rank = it->second; + } + } + } + if (rank == INT32_MAX) { + return false; + } + auto start = splits[merge_idx]; + auto mid = splits[merge_idx + 1]; + auto end = splits[merge_idx + 2]; + std::string merge_l = segment.substr(start, mid - start); + std::string merge_r = segment.substr(mid, end - mid); + for (int i = splits.size() - 2; i >= 0; --i) { + auto start = splits[i]; + auto mid = splits[i + 1]; + auto end = splits[i + 2]; + if (segment.substr(start, mid - start) == merge_l && + segment.substr(mid, end - mid) == merge_r) { + splits.erase(splits.begin() + i + 1); + i -= 1; + } + } + return true; + }; + + std::vector ids; + ids.push_back(bos_id_); + + // Initialize merges to integer list + auto merge_segment = [&ids, &one_step_merge, + this](const std::string &segment) { + std::vector splits; + for (int i = 0; i < segment.size(); ++i) { + splits.push_back(i); + if (static_cast(segment[i]) >= 128) { + i++; + } + } + splits.push_back(segment.size()); + + while (one_step_merge(segment, splits)) { + }; + for (int i = 0; i < splits.size() - 1; ++i) { + auto start = splits[i]; + auto end = splits[i + 1]; + std::string s = segment.substr(start, end - start); + if (auto it = token_to_id_.find(s); it != token_to_id_.end()) { + ids.push_back(it->second); + } else { + throw std::runtime_error("UNK ENCOUNTERED"); + } + } + }; + + for (auto &segment : segments) { + merge_segment(segment); + } + return ids; +} + +std::string BPETokenizer::id_to_bytes(int id) const { + std::string token; + auto [wide_token, _] = utf8_to_utf16(id_to_token_[id]); + token.resize(wide_token.size()); + for (int i = 0; i < wide_token.size(); ++i) { + token[i] = byte_decoder_[wide_token[i]]; + } + return token; +} + +std::pair +BPETokenizer::try_decode(const std::vector &ids) const { + std::string text; + for (auto id : ids) { + text += id_to_bytes(id); + } + auto [_, converted] = utf8_to_utf16(text); + bool complete = converted == text.size(); + text.resize(converted); + return {text, complete}; +} + +std::string BPETokenizer::decode(const std::vector &ids) const { + return try_decode(ids).first; +} + +int BPETokenizer::eos_token_id() const { return eos_id_; } diff --git a/llms/export/tokenizer.h b/llms/export/tokenizer.h new file mode 100644 index 000000000..64dd2fe42 --- /dev/null +++ b/llms/export/tokenizer.h @@ -0,0 +1,37 @@ +// Copyright © 2024 Apple Inc. + +#include +#include +#include + +#pragma once + +/** BPE Tokenizer API */ +class BPETokenizer { +public: + BPETokenizer(const std::string &path); + + /** Encode a string of text to token integer ids. */ + std::vector encode(std::string text) const; + + /** Try to decode the vector of ids to text. The text is truncated to + * include only the fully decodable tokens. */ + std::string decode(const std::vector &ids) const; + + /** Try to decode the vector of ids to text. The second return value + * indicates if the decoding completed. The text is truncated to include + * only the fully decodable tokens. 
*/ + std::pair try_decode(const std::vector &ids) const; + + int eos_token_id() const; + +private: + std::unordered_map token_to_id_; + std::vector id_to_token_; + std::unordered_map merges_; + int bos_id_; + int eos_id_; + static std::unordered_map byte_decoder_; + std::string id_to_bytes(int id) const; + std::vector pre_tokenizer_regex_; +}; diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 7b452ea47..4a789dcf3 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -69,14 +69,12 @@ def __call__( mask: Optional[mx.array] = None, cache: Optional[Any] = None, ) -> mx.array: - B, L, D = x.shape - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + queries = mx.unflatten(queries, -1, (self.n_heads, -1)).transpose(0, 2, 1, 3) + keys = mx.unflatten(keys, -1, (self.n_kv_heads, -1)).transpose(0, 2, 1, 3) + values = mx.unflatten(values, -1, (self.n_kv_heads, -1)).transpose(0, 2, 1, 3) if cache is not None: queries = self.rope(queries, offset=cache.offset) @@ -90,7 +88,7 @@ def __call__( queries, keys, values, cache=cache, scale=self.scale, mask=mask ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + output = output.transpose(0, 2, 1, 3).flatten(-2, -1) return self.o_proj(output)
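For reference, the attention change in `llama.py` swaps the explicit `reshape(B, L, ...)` calls for `mx.unflatten` and `flatten`, so the head split and the final concat no longer hard-code the batch and sequence sizes; presumably this is what keeps the module friendly to the shapeless compile/export used elsewhere in this PR. A quick equivalence check (illustrative, not part of the diff):

```python
# The unflatten/flatten forms compute the same arrays as the old reshape-based code.
import mlx.core as mx

B, L, n_heads, head_dim = 2, 5, 4, 8
x = mx.random.normal((B, L, n_heads * head_dim))

old = x.reshape(B, L, n_heads, -1).transpose(0, 2, 1, 3)
new = mx.unflatten(x, -1, (n_heads, -1)).transpose(0, 2, 1, 3)
assert mx.array_equal(old, new)

# Inverse on the attention output: back to (B, L, n_heads * head_dim).
out = new.transpose(0, 2, 1, 3).flatten(-2, -1)
assert out.shape == (B, L, n_heads * head_dim)
```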