From cdba586b679cef1721c3599328acb0729f8af9bf Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 9 Oct 2024 12:46:44 -0700
Subject: [PATCH 1/5] caching in server

---
 llms/mlx_lm/SERVER.md           |  24 +++++++
 llms/mlx_lm/server.py           | 112 +++++++++++++++++++++++++-------
 llms/tests/test_prompt_cache.py |  23 +++++++
 3 files changed, 134 insertions(+), 25 deletions(-)

diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md
index 55be1c9ca..58a163920 100644
--- a/llms/mlx_lm/SERVER.md
+++ b/llms/mlx_lm/SERVER.md
@@ -86,6 +86,30 @@ curl localhost:8080/v1/chat/completions \
 - `adapters`: (Optional) A string path to low-rank adapters. The path must be
   rlative to the directory the server was started in.

+### Response Fields
+
+- `id`: A unique identifier for the chat.
+- `system_fingerprint`: A unique identifier for the system.
+- `object`: Any of "chat.completions", "chat.completions.chunk" (for
+  streaming), or "text.completion".
+- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
+- `created`: A timestamp for when the request was processed.
+- `choices`: A list of outputs. Each output is a dictionary containing the fields:
+  - `index`: The index in the list.
+  - `logprobs`: A dictionary containing the fields:
+    - `token_logprobs`: A list of the log probabilities for the generated
+      tokens.
+    - `tokens`: A list of the generated token ids.
+    - `top_logprobs`: A list of lists. Each list contains the `logprobs`
+      top tokens (if requested) with their corresponding probabilities.
+  - `finish_reason`: The reason the completion ended. This can be either of
+    `"stop"` or `"length"`.
+  - `message`: The text response from the model.
+- `usage`: A dictionary containing the fields:
+  - `prompt_tokens`: The number of prompt tokens processed.
+  - `completion_tokens`: The number of tokens generated.
+  - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
+
 ### List Models

 Use the `v1/models` endpoint to list available models:
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 42962b547..f87c1d005 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -3,16 +3,30 @@
 import argparse
 import json
 import logging
+import platform
 import time
 import uuid
 import warnings
+from dataclasses import dataclass, field
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from pathlib import Path
-from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)

 import mlx.core as mx
 from huggingface_hub import scan_cache_dir

+from ._version import __version__
+from .models.cache import make_prompt_cache
 from .utils import generate_step, load
@@ -94,6 +108,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()


+@dataclass
+class PromptCache:
+    cache: List[Any] = field(default_factory=list)
+    model_key: Tuple[str, Optional[str]] = ("", None)
+    tokens: List[int] = field(default_factory=list)
+
+
 class ModelProvider:
     def __init__(self, cli_args: argparse.Namespace):
         """Load models on demand and persist them across the whole process."""
@@ -156,12 +177,21 @@ def load(self, model_path, adapter_path=None):


 class APIHandler(BaseHTTPRequestHandler):
-    def __init__(self, model_provider: ModelProvider, *args, **kwargs):
+    def __init__(
+        self,
+        model_provider: ModelProvider,
+        prompt_cache: PromptCache,
+        system_fingerprint: str,
+        *args,
+        **kwargs,
+    ):
         """
         Create static request specific metadata
         """
         self.created = int(time.time())
         self.model_provider = model_provider
+        self.prompt_cache = prompt_cache
+        self.system_fingerprint = system_fingerprint
         super().__init__(*args, **kwargs)

     def _set_cors_headers(self):
@@ -215,7 +245,9 @@ def do_POST(self):
         self.stream_options = self.body.get("stream_options", None)
         self.requested_model = self.body.get("model", "default_model")
         self.adapter = self.body.get("adapters", None)
-        self.max_tokens = self.body.get("max_tokens", 100)
+        self.max_tokens = self.body.get("max_completion_tokens", None)
+        if self.max_tokens is None:
+            self.max_tokens = self.body.get("max_tokens", 512)
         self.temperature = self.body.get("temperature", 1.0)
         self.top_p = self.body.get("top_p", 1.0)
         self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
@@ -343,7 +375,7 @@ def generate_response(
         # Static response
         response = {
             "id": self.request_id,
-            "system_fingerprint": f"fp_{uuid.uuid4()}",
+            "system_fingerprint": self.system_fingerprint,
             "object": self.object_type,
             "model": self.requested_model,
             "created": self.created,
@@ -388,16 +420,30 @@ def generate_response(

         return response

+    def get_prompt_cache(self, prompt):
+        cache_len = len(self.prompt_cache.tokens)
+        if (
+            self.prompt_cache.model_key != self.model_provider.model_key
+            or cache_len >= len(prompt)
+            or self.prompt_cache.tokens != prompt[:cache_len]
+        ):
+            self.prompt_cache.model_key = self.model_provider.model_key
+            self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
+        else:
+            prompt = prompt[cache_len:]
+        self.prompt_cache.tokens.extend(prompt)
+        return prompt
+
     def handle_completion(
         self,
-        prompt: mx.array,
+        prompt: List[int],
         stop_id_sequences: List[List[int]],
     ):
         """
         Generate a response to a prompt and send it to the client in a single batch.

         Args:
-            prompt (mx.array): The prompt, in token form inside of a mlx array
+            prompt (List[int]): The tokenized prompt.
             stop_id_sequences (List[List[int]]): A list of stop words passed
                 to the stopping_criteria function
         """
@@ -409,7 +455,12 @@ def handle_completion(
         logging.debug(f"Starting completion:")
         token_logprobs = []
         top_tokens = []
-        for (token, logprobs), _ in zip(
+
+        prompt = self.get_prompt_cache(prompt)
+        prompt = mx.array(prompt)
+
+        for _, (token, logprobs) in zip(
+            range(self.max_tokens),
             generate_step(
                 prompt=prompt,
                 model=self.model,
@@ -418,8 +469,8 @@ def handle_completion(
                 repetition_penalty=self.repetition_penalty,
                 repetition_context_size=self.repetition_context_size,
                 logit_bias=self.logit_bias,
+                prompt_cache=self.prompt_cache.cache,
             ),
-            range(self.max_tokens),
         ):
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
@@ -430,7 +481,7 @@ def handle_completion(
                 top_indices = sorted_indices[: self.logprobs]
                 top_logprobs = logprobs[top_indices]
                 top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
-                top_tokens.append(dict(top_token_info))
+                top_tokens.append(tuple(top_token_info))

             token_logprobs.append(logprobs[token].item())

@@ -445,6 +496,7 @@ def handle_completion(
                 )
                 break

+        self.prompt_cache.tokens.extend(tokens)
         detokenizer.finalize()
         text = (
             detokenizer.text
@@ -474,7 +526,7 @@ def handle_completion(

     def handle_stream(
         self,
-        prompt: mx.array,
+        prompt: List[int],
         stop_id_sequences: List[List[int]],
     ):
         """
@@ -482,7 +534,7 @@ def handle_stream(
         Sent Events (SSE) stream.

         Args:
-            prompt (mx.array): The prompt, in token form inside of a mlx array
+            prompt (mx.array): The tokenized prompt
             stop_id_sequences (List[List[int]]): A list of stop words passed to
                 the stopping_criteria function
         """
@@ -496,16 +548,20 @@ def handle_stream(
         stop_sequence_suffix = None
         logging.debug(f"Starting stream:")

-        for (token, _), _ in zip(
+        prompt = self.get_prompt_cache(prompt)
+        prompt = mx.array(prompt)
+
+        for _, (token, _) in zip(
+            range(self.max_tokens),
             generate_step(
-                prompt=prompt,
+                prompt=mx.array(prompt),
                 model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
                 repetition_context_size=self.repetition_context_size,
+                prompt_cache=self.prompt_cache.cache,
             ),
-            range(self.max_tokens),
         ):
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
@@ -531,9 +587,12 @@ def handle_stream(
                 continue

             new_text = detokenizer.last_segment
-            response = self.generate_response(new_text, None)
-            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-            self.wfile.flush()
+            if new_text:
+                response = self.generate_response(new_text, None)
+                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                self.wfile.flush()
+
+        self.prompt_cache.tokens.extend(tokens)

         # check is there any remaining text to send
         detokenizer.finalize()
@@ -559,7 +618,7 @@ def completion_usage_response(
     ):
         response = {
             "id": self.request_id,
-            "system_fingerprint": f"fp_{uuid.uuid4()}",
+            "system_fingerprint": self.system_fingerprint,
             "object": "chat.completion",
             "model": self.requested_model,
             "created": self.created,
@@ -587,7 +646,6 @@ def handle_chat_completions(self) -> mx.array:
         self.object_type = (
             "chat.completions.chunk" if self.stream else "chat.completions"
         )
-
         if (
             hasattr(self.tokenizer, "apply_chat_template")
             and self.tokenizer.chat_template
@@ -602,7 +660,7 @@ def handle_chat_completions(self) -> mx.array:
             prompt = convert_chat(body["messages"], body.get("role_mapping"))
             prompt = self.tokenizer.encode(prompt)

-        return mx.array(prompt)
+        return prompt

     def handle_text_completions(self) -> mx.array:
         """
         Handle a text completion request.
@@ -614,11 +672,8 @@
         # Determine response type
         self.request_id = f"cmpl-{uuid.uuid4()}"
         self.object_type = "text_completion"
-
         assert "prompt" in self.body, "Request did not contain a prompt"
-        prompt_text = self.body["prompt"]
-        prompt = self.tokenizer.encode(prompt_text)
-        return mx.array(prompt)
+        return self.tokenizer.encode(self.body["prompt"])

     def do_GET(self):
         """
@@ -669,9 +724,16 @@ def run(
     handler_class=APIHandler,
 ):
     server_address = (host, port)
+    prompt_cache = PromptCache()
+    system_fingerprint = (
+        f"{__version__}-{mx.__version__}-{platform.platform()}-"
+        f"{mx.metal.device_info().get('architecture', '')}"
+    )
     httpd = server_class(
         server_address,
-        lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
+        lambda *args, **kwargs: handler_class(
+            model_provider, prompt_cache, system_fingerprint, *args, **kwargs
+        ),
     )
     warnings.warn(
         "mlx_lm.server is not recommended for production as "
diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py
index 3c1ef49b3..64cd9486d 100644
--- a/llms/tests/test_prompt_cache.py
+++ b/llms/tests/test_prompt_cache.py
@@ -1,5 +1,6 @@
 # Copyright © 2024 Apple Inc.

+import copy
 import os
 import tempfile
 import unittest
@@ -215,6 +216,28 @@ def test_trim_cache_with_generate(self):
             all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
         )

+    def test_cache_copying(self):
+        cache = [KVCache()]
+
+        x = mx.random.uniform(shape=(1, 8, 10, 4))
+        cache[0].update_and_fetch(x, x)
+
+        y = mx.random.uniform(shape=(1, 8, 1, 4))
+        cache[0].update_and_fetch(y, y)
+
+        old_cache = copy.deepcopy(cache)
+
+        trim_prompt_cache(cache, 1)
+
+        self.assertTrue(old_cache[0].offset, 11)
+        self.assertTrue(cache[0].offset, 10)
+
+        z = mx.random.uniform(shape=(1, 8, 1, 4))
+        cache[0].update_and_fetch(z, z)
+
+        self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
+        self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
+

 if __name__ == "__main__":
     unittest.main()
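PATCH 1 keeps the key/value cache of the most recent request in a `PromptCache` and, when the next tokenized prompt starts with the same tokens for the same model, only feeds the unseen suffix to `generate_step`. The sketch below is an editorial illustration, not part of the series: it shows the multi-turn chat pattern that benefits, assuming a local server started with `mlx_lm.server` on the default `localhost:8080`, the third-party `requests` package, and that the chat endpoint returns the assistant turn as an OpenAI-style `{"role", "content"}` message object.

```python
# Editorial sketch (not part of the patches): the request pattern the new
# PromptCache is designed to speed up. Field names follow SERVER.md above.
import requests

URL = "http://localhost:8080/v1/chat/completions"
messages = [{"role": "user", "content": "Summarize the Gettysburg Address."}]

# First turn: the full prompt is processed and its KV cache stays on the server.
first = requests.post(URL, json={"messages": messages, "max_tokens": 128}).json()

# Second turn: the earlier conversation is a prefix of the new prompt, so the
# server only needs to process the newly appended tokens.
messages.append(first["choices"][0]["message"])  # assistant reply (role/content)
messages.append({"role": "user", "content": "Now compress that to one sentence."})
second = requests.post(URL, json={"messages": messages, "max_tokens": 128}).json()

print(second["choices"][0]["message"]["content"])
print(second["usage"])  # prompt_tokens / completion_tokens / total_tokens
```

The cache is keyed on the model provider's `model_key`, so switching the requested model or adapters simply resets it.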
From d85010bf4bc161c4d2b4a502e33a0d04fbc52385 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 9 Oct 2024 12:49:32 -0700
Subject: [PATCH 2/5] nits

---
 llms/mlx_lm/SERVER.md | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md
index 58a163920..2976a09fc 100644
--- a/llms/mlx_lm/SERVER.md
+++ b/llms/mlx_lm/SERVER.md
@@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \
 - `role_mapping`: (Optional) A dictionary to customize the role prefixes in
   the generated prompt. If not provided, the default mappings are used.

-- `stop`: (Optional) An array of strings or a single string. Thesse are
+- `stop`: (Optional) An array of strings or a single string. These are
   sequences of tokens on which the generation should stop.

 - `max_tokens`: (Optional) An integer specifying the maximum number of tokens
@@ -84,16 +84,21 @@ curl localhost:8080/v1/chat/completions \
   started in.

 - `adapters`: (Optional) A string path to low-rank adapters. The path must be
-  rlative to the directory the server was started in.
+  relative to the directory the server was started in.

 ### Response Fields

 - `id`: A unique identifier for the chat.
+
 - `system_fingerprint`: A unique identifier for the system.
+
 - `object`: Any of "chat.completions", "chat.completions.chunk" (for
   streaming), or "text.completion".
+
 - `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
-- `created`: A timestamp for when the request was processed.
+
+- `created`: A time-stamp for when the request was processed.
+
 - `choices`: A list of outputs. Each output is a dictionary containing the fields:
   - `index`: The index in the list.
   - `logprobs`: A dictionary containing the fields:
     - `token_logprobs`: A list of the log probabilities for the generated
       tokens.
     - `tokens`: A list of the generated token ids.
     - `top_logprobs`: A list of lists. Each list contains the `logprobs`
       top tokens (if requested) with their corresponding probabilities.
   - `finish_reason`: The reason the completion ended. This can be either of
     `"stop"` or `"length"`.
   - `message`: The text response from the model.
+
 - `usage`: A dictionary containing the fields:
   - `prompt_tokens`: The number of prompt tokens processed.
   - `completion_tokens`: The number of tokens generated.
   - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
@@ -121,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json"
 This will return a list of locally available models where each model in the
 list contains the following fields:

-- `"id"`: The Hugging Face repo id.
-- `"created"`: A timestamp representing the model creation time.
+- `id`: The Hugging Face repo id.
+- `created`: A time-stamp representing the model creation time.
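The Response Fields section added in PATCH 1 and tidied in PATCH 2 can be exercised directly. The snippet below is an editorial sketch, not part of the patches: it assumes the same local server and `requests` package as above, plus the `logprobs` request parameter documented elsewhere in SERVER.md (not shown in this excerpt).

```python
# Editorial sketch (not part of the patches): reading the documented response
# fields from a non-streaming chat completion.
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 16,
        "logprobs": 3,  # assumed request field: top-3 alternatives per generated token
    },
).json()

choice = resp["choices"][0]
print(resp["object"], resp["model"], resp["system_fingerprint"])
print(choice["finish_reason"])                # "stop" or "length"
print(choice["logprobs"]["tokens"])           # generated token ids
print(choice["logprobs"]["token_logprobs"])   # log probability of each generated token
print(choice["logprobs"]["top_logprobs"][0])  # top candidates for the first token
print(resp["usage"])                          # prompt/completion/total token counts
```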
From d6222ae7ff2ddf662a5e79ca63bd8bcaacb36ef9 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 9 Oct 2024 13:17:56 -0700
Subject: [PATCH 3/5] fix tests

---
 llms/mlx_lm/server.py     | 25 ++++++++++++++---------
 llms/tests/test_server.py |  1 +
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index f87c1d005..47a725994 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -30,6 +30,13 @@
 from .utils import generate_step, load


+def get_system_fingerprint():
+    return (
+        f"{__version__}-{mx.__version__}-{platform.platform()}-"
+        f"{mx.metal.device_info().get('architecture', '')}"
+    )
+
+
 class StopCondition(NamedTuple):
     stop_met: bool
     trim_length: int
@@ -180,9 +187,9 @@ class APIHandler(BaseHTTPRequestHandler):
     def __init__(
         self,
         model_provider: ModelProvider,
-        prompt_cache: PromptCache,
-        system_fingerprint: str,
         *args,
+        prompt_cache: Optional[PromptCache] = None,
+        system_fingerprint: Optional[str] = None,
         **kwargs,
     ):
         """
@@ -190,8 +197,8 @@ def __init__(
         """
         self.created = int(time.time())
         self.model_provider = model_provider
-        self.prompt_cache = prompt_cache
-        self.system_fingerprint = system_fingerprint
+        self.prompt_cache = prompt_cache or PromptCache()
+        self.system_fingerprint = system_fingerprint or get_system_fingerprint()
         super().__init__(*args, **kwargs)

     def _set_cors_headers(self):
@@ -725,14 +732,14 @@ def run(
 ):
     server_address = (host, port)
     prompt_cache = PromptCache()
-    system_fingerprint = (
-        f"{__version__}-{mx.__version__}-{platform.platform()}-"
-        f"{mx.metal.device_info().get('architecture', '')}"
-    )
     httpd = server_class(
         server_address,
         lambda *args, **kwargs: handler_class(
-            model_provider, prompt_cache, system_fingerprint, *args, **kwargs
+            model_provider,
+            prompt_cache=prompt_cache,
+            system_fingerprint=get_system_fingerprint(),
+            *args,
+            **kwargs,
         ),
     )
     warnings.warn(
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index cbcccfbef..ad17554d1 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -14,6 +14,7 @@ class DummyModelProvider:
     def __init__(self):
         HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
         self.model, self.tokenizer = load(HF_MODEL_PATH)
+        self.model_key = (HF_MODEL_PATH, None)

     def load(self, model, adapter=None):
         assert model in ["default_model", "chat_model"]

From 1b05b51dc533cff377713aa5edbe2bdc69cbcf16 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 9 Oct 2024 19:13:33 -0700
Subject: [PATCH 4/5] don't throw if no metal

---
 llms/mlx_lm/server.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 47a725994..eadf951b6 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -31,10 +31,8 @@


 def get_system_fingerprint():
-    return (
-        f"{__version__}-{mx.__version__}-{platform.platform()}-"
-        f"{mx.metal.device_info().get('architecture', '')}"
-    )
+    gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
+    return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"


 class StopCondition(NamedTuple):

From 5c4e6ce279bdea0c2bb364195c34b4d15f4cc630 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 14 Oct 2024 10:47:58 -0700
Subject: [PATCH 5/5] comments

---
 llms/mlx_lm/server.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index eadf951b6..ec6599695 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -462,12 +462,11 @@ def handle_completion(
         top_tokens = []

         prompt = self.get_prompt_cache(prompt)
-        prompt = mx.array(prompt)

         for _, (token, logprobs) in zip(
             range(self.max_tokens),
             generate_step(
-                prompt=prompt,
+                prompt=mx.array(prompt),
                 model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
@@ -554,7 +553,6 @@ def handle_stream(
         logging.debug(f"Starting stream:")

         prompt = self.get_prompt_cache(prompt)
-        prompt = mx.array(prompt)

         for _, (token, _) in zip(
             range(self.max_tokens),
@@ -636,7 +634,7 @@
         }
         return response

-    def handle_chat_completions(self) -> mx.array:
+    def handle_chat_completions(self) -> List[int]:
         """
         Handle a chat completion request.

@@ -667,7 +665,7 @@

         return prompt

-    def handle_text_completions(self) -> mx.array:
+    def handle_text_completions(self) -> List[int]:
         """
         Handle a text completion request.
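After the whole series is applied, the caching decision lives in `get_prompt_cache` (introduced in PATCH 1, simplified in PATCH 5). The following dependency-free sketch is editorial: it restates that decision over plain Python lists so the control flow is easy to see, and it simplifies the bookkeeping (the real handler also rebuilds the MLX KV cache with `make_prompt_cache` and stores it on the `PromptCache`).

```python
# Editorial sketch (not part of the patches): the prefix-reuse rule behind
# get_prompt_cache, with the cache's token list simply replaced on reset.
from typing import List, Optional, Tuple

ModelKey = Tuple[str, Optional[str]]  # (model path, adapter path), as in PromptCache


def reuse_or_reset(
    cache_tokens: List[int],  # tokens currently held by the cache
    cache_key: ModelKey,
    prompt: List[int],        # tokenized incoming request
    model_key: ModelKey,
) -> Tuple[List[int], bool]:
    """Return (tokens still to be processed, whether the cache was reused)."""
    cache_len = len(cache_tokens)
    if (
        cache_key != model_key                 # different model/adapter: cache invalid
        or cache_len >= len(prompt)            # cache must be a strict prefix of the prompt
        or cache_tokens != prompt[:cache_len]  # cached tokens must match the prompt prefix
    ):
        # Reset: process the whole prompt from scratch.
        cache_tokens[:] = prompt
        return prompt, False
    # Reuse: only the unseen suffix needs to be processed.
    suffix = prompt[cache_len:]
    cache_tokens.extend(suffix)
    return suffix, True


# A follow-up request that extends the cached prompt by three tokens only
# costs three tokens of prompt processing.
tokens, key = [1, 2, 3, 4], ("default_model", None)
print(reuse_or_reset(tokens, key, [1, 2, 3, 4, 5, 6, 7], key))  # ([5, 6, 7], True)
```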