Prompt caching in mlx_lm.server #1026

Merged (5 commits) on Oct 14, 2024

38 changes: 34 additions & 4 deletions llms/mlx_lm/SERVER.md
@@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
the generated prompt. If not provided, the default mappings are used.

- `stop`: (Optional) An array of strings or a single string. Thesse are
- `stop`: (Optional) An array of strings or a single string. These are
sequences of tokens on which the generation should stop.

- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
@@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \
started in.

- `adapters`: (Optional) A string path to low-rank adapters. The path must be
rlative to the directory the server was started in.
relative to the directory the server was started in.
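
For illustration, here is a minimal Python sketch of a request exercising some of these fields; the model name and all field values are illustrative, not defaults:

```python
# Hedged sketch: POST a chat completion request to a locally running
# mlx_lm.server instance. All field values here are illustrative.
import json
from urllib import request

payload = {
    "model": "mlx-community/Llama-3.2-3B-Instruct-4bit",
    "messages": [{"role": "user", "content": "Write a haiku about the ocean."}],
    "stop": ["\n\n"],    # stop at the first blank line
    "max_tokens": 128,   # cap the number of generated tokens
}
req = request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0])
```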

### Response Fields

- `id`: A unique identifier for the chat.

- `system_fingerprint`: A unique identifier for the system.

- `object`: Any of "chat.completions", "chat.completions.chunk" (for
streaming), or "text.completion".

- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).

- `created`: A time-stamp for when the request was processed.

- `choices`: A list of outputs. Each output is a dictionary containing the fields:
- `index`: The index in the list.
- `logprobs`: A dictionary containing the fields:
- `token_logprobs`: A list of the log probabilities for the generated
tokens.
- `tokens`: A list of the generated token ids.
- `top_logprobs`: A list of lists. Each list contains the `logprobs`
top tokens (if requested) with their corresponding probabilities.
- `finish_reason`: The reason the completion ended. This can be either of
`"stop"` or `"length"`.
- `message`: The text response from the model.

- `usage`: A dictionary containing the fields:
- `prompt_tokens`: The number of prompt tokens processed.
- `completion_tokens`: The number of tokens generated.
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
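
Put together, a non-streaming chat response has roughly the following shape; every value below is a placeholder, not real output:

```python
# Hypothetical response assembled from the fields documented above;
# all values are placeholders.
example_response = {
    "id": "<unique chat identifier>",
    "system_fingerprint": "<system identifier>",
    "object": "chat.completions",
    "model": "mlx-community/Llama-3.2-3B-Instruct-4bit",
    "created": 1728900000,          # request timestamp (Unix seconds)
    "choices": [
        {
            "index": 0,
            "logprobs": {
                "token_logprobs": [-0.11, -0.73],  # log probs of generated tokens
                "tokens": [1234, 5678],            # generated token ids
                "top_logprobs": [],                # only filled when requested
            },
            "finish_reason": "stop",               # or "length"
            "message": "<text response from the model>",
        }
    ],
    "usage": {
        "prompt_tokens": 12,
        "completion_tokens": 2,
        "total_tokens": 14,
    },
}
```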

### List Models

Expand All @@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json"
This will return a list of locally available models where each model in the
list contains the following fields:

- `"id"`: The Hugging Face repo id.
- `"created"`: A timestamp representing the model creation time.
- `id`: The Hugging Face repo id.
- `created`: A time-stamp representing the model creation time.
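
The same query from Python, as a small sketch (assumes the default host and port):

```python
# Hedged sketch: list the locally available models and print the raw JSON.
# Each entry carries the `id` and `created` fields documented above.
import json
from urllib import request

with request.urlopen("http://localhost:8080/v1/models") as resp:
    print(json.dumps(json.load(resp), indent=2))
```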
121 changes: 93 additions & 28 deletions llms/mlx_lm/server.py
@@ -3,19 +3,38 @@
import argparse
import json
import logging
import platform
import time
import uuid
import warnings
from dataclasses import dataclass, field
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
from typing import (
Any,
Dict,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)

import mlx.core as mx
from huggingface_hub import scan_cache_dir

from ._version import __version__
from .models.cache import make_prompt_cache
from .utils import generate_step, load


def get_system_fingerprint():
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"


class StopCondition(NamedTuple):
stop_met: bool
trim_length: int
@@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
return prompt.rstrip()


@dataclass
class PromptCache:
cache: List[Any] = field(default_factory=list)
model_key: Tuple[str, Optional[str]] = ("", None)
tokens: List[int] = field(default_factory=list)


class ModelProvider:
def __init__(self, cli_args: argparse.Namespace):
"""Load models on demand and persist them across the whole process."""
@@ -156,12 +182,21 @@ def load(self, model_path, adapter_path=None):


class APIHandler(BaseHTTPRequestHandler):
def __init__(self, model_provider: ModelProvider, *args, **kwargs):
def __init__(
self,
model_provider: ModelProvider,
*args,
prompt_cache: Optional[PromptCache] = None,
system_fingerprint: Optional[str] = None,
**kwargs,
):
"""
Create static request specific metadata
"""
self.created = int(time.time())
self.model_provider = model_provider
self.prompt_cache = prompt_cache or PromptCache()
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
super().__init__(*args, **kwargs)

def _set_cors_headers(self):
@@ -215,7 +250,9 @@ def do_POST(self):
self.stream_options = self.body.get("stream_options", None)
self.requested_model = self.body.get("model", "default_model")
self.adapter = self.body.get("adapters", None)
self.max_tokens = self.body.get("max_tokens", 100)
self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512)
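# The newer `max_completion_tokens` field takes precedence; `max_tokens` is
# kept as a fallback, now defaulting to 512.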
self.temperature = self.body.get("temperature", 1.0)
self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
@@ -343,7 +380,7 @@ def generate_response(
# Static response
response = {
"id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"system_fingerprint": self.system_fingerprint,
"object": self.object_type,
"model": self.requested_model,
"created": self.created,
@@ -388,16 +425,30 @@ def generate_response(

return response

def get_prompt_cache(self, prompt):
cache_len = len(self.prompt_cache.tokens)
if (
self.prompt_cache.model_key != self.model_provider.model_key
or cache_len >= len(prompt)
or self.prompt_cache.tokens != prompt[:cache_len]
):
self.prompt_cache.model_key = self.model_provider.model_key
self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
else:
prompt = prompt[cache_len:]
self.prompt_cache.tokens.extend(prompt)
return prompt
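
In words: the KV cache is reused only when the model/adapter is unchanged and the cached tokens form a proper prefix of the new prompt; otherwise it is rebuilt from scratch. A standalone sketch of that check, with hypothetical toy token lists (not part of server.py):

```python
# Paraphrase of the reuse rule in get_prompt_cache above, for illustration only.
def reusable_suffix(cached_tokens, prompt, same_model=True):
    """Return the part of `prompt` still to be processed, or None to rebuild."""
    n = len(cached_tokens)
    if not same_model or n >= len(prompt) or cached_tokens != prompt[:n]:
        return None              # model changed, or cache is not a proper prefix
    return prompt[n:]            # only the new suffix needs to be prefilled

assert reusable_suffix([1, 2, 3], [1, 2, 3, 4, 5]) == [4, 5]  # prefix hit
assert reusable_suffix([1, 2, 3], [1, 9, 3, 4]) is None       # prompt diverged
assert reusable_suffix([1, 2, 3], [1, 2, 3]) is None          # nothing new to process
```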

def handle_completion(
self,
prompt: mx.array,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
Generate a response to a prompt and send it to the client in a single batch.

Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
prompt (List[int]): The tokenized prompt.
stop_id_sequences (List[List[int]]): A list of stop words passed
to the stopping_criteria function
"""
@@ -409,17 +460,21 @@ def handle_completion(
logging.debug(f"Starting completion:")
token_logprobs = []
top_tokens = []
for (token, logprobs), _ in zip(

prompt = self.get_prompt_cache(prompt)
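# The zip with range(max_tokens) below caps the number of generation steps;
# generate_step resumes from the shared prompt cache, so only the uncached
# suffix of the prompt is prefilled.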

for _, (token, logprobs) in zip(
range(self.max_tokens),
generate_step(
prompt=prompt,
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache,
),
range(self.max_tokens),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
@@ -430,7 +485,7 @@ def handle_completion(
top_indices = sorted_indices[: self.logprobs]
top_logprobs = logprobs[top_indices]
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
top_tokens.append(dict(top_token_info))
top_tokens.append(tuple(top_token_info))

token_logprobs.append(logprobs[token].item())

@@ -445,6 +500,7 @@ def handle_completion(
)
break
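# Also record the generated tokens, so a follow-up request that extends this
# conversation can keep reusing the KV cache.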

self.prompt_cache.tokens.extend(tokens)
detokenizer.finalize()
text = (
detokenizer.text
@@ -474,15 +530,15 @@ def handle_completion(

def handle_stream(
self,
prompt: mx.array,
prompt: List[int],
Contributor: This in particular. Can't we handle a single or multiple (batched) prompt that falls back to the behavior for a single prompt by default?

Member (author): We can always update this type to List[List[int]] when the time comes.

stop_id_sequences: List[List[int]],
):
"""
Generate a response to the prompt and forward it to the client using a Server
Sent Events (SSE) stream.

Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
prompt (List[int]): The tokenized prompt
stop_id_sequences (List[List[int]]): A list of stop words passed to
the stopping_criteria function
"""
@@ -496,16 +552,19 @@ def handle_stream(
stop_sequence_suffix = None
logging.debug(f"Starting stream:")

for (token, _), _ in zip(
prompt = self.get_prompt_cache(prompt)

for _, (token, _) in zip(
range(self.max_tokens),
generate_step(
prompt=prompt,
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
prompt_cache=self.prompt_cache.cache,
),
range(self.max_tokens),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
@@ -531,9 +590,12 @@ def handle_stream(
continue

new_text = detokenizer.last_segment
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
if new_text:
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()

self.prompt_cache.tokens.extend(tokens)

# check if there is any remaining text to send
detokenizer.finalize()
@@ -559,7 +621,7 @@ def completion_usage_response(
):
response = {
"id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"system_fingerprint": self.system_fingerprint,
"object": "chat.completion",
"model": self.requested_model,
"created": self.created,
@@ -572,7 +634,7 @@
}
return response

def handle_chat_completions(self) -> mx.array:
def handle_chat_completions(self) -> List[int]:
"""
Handle a chat completion request.

@@ -587,7 +649,6 @@ def handle_chat_completions(self) -> mx.array:
self.object_type = (
"chat.completions.chunk" if self.stream else "chat.completions"
)

if (
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template
@@ -602,9 +663,9 @@ def handle_chat_completions(self) -> mx.array:
prompt = convert_chat(body["messages"], body.get("role_mapping"))
prompt = self.tokenizer.encode(prompt)

return mx.array(prompt)
return prompt

def handle_text_completions(self) -> mx.array:
def handle_text_completions(self) -> List[int]:
"""
Handle a text completion request.

@@ -614,11 +675,8 @@ def handle_text_completions(self) -> mx.array:
# Determine response type
self.request_id = f"cmpl-{uuid.uuid4()}"
self.object_type = "text_completion"

assert "prompt" in self.body, "Request did not contain a prompt"
prompt_text = self.body["prompt"]
prompt = self.tokenizer.encode(prompt_text)
return mx.array(prompt)
return self.tokenizer.encode(self.body["prompt"])

def do_GET(self):
"""
@@ -669,9 +727,16 @@ def run(
handler_class=APIHandler,
):
server_address = (host, port)
prompt_cache = PromptCache()
httpd = server_class(
server_address,
lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
lambda *args, **kwargs: handler_class(
model_provider,
prompt_cache=prompt_cache,
system_fingerprint=get_system_fingerprint(),
*args,
**kwargs,
),
)
warnings.warn(
"mlx_lm.server is not recommended for production as "
23 changes: 23 additions & 0 deletions llms/tests/test_prompt_cache.py
@@ -1,5 +1,6 @@
# Copyright © 2024 Apple Inc.

import copy
import os
import tempfile
import unittest
@@ -215,6 +216,28 @@ def test_trim_cache_with_generate(self):
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
)

def test_cache_copying(self):
cache = [KVCache()]

x = mx.random.uniform(shape=(1, 8, 10, 4))
cache[0].update_and_fetch(x, x)

y = mx.random.uniform(shape=(1, 8, 1, 4))
cache[0].update_and_fetch(y, y)

old_cache = copy.deepcopy(cache)
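# Trimming the live cache below must leave this deep copy untouched.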

trim_prompt_cache(cache, 1)

self.assertEqual(old_cache[0].offset, 11)
self.assertEqual(cache[0].offset, 10)

z = mx.random.uniform(shape=(1, 8, 1, 4))
cache[0].update_and_fetch(z, z)

self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions llms/tests/test_server.py
@@ -14,6 +14,7 @@ class DummyModelProvider:
def __init__(self):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
self.model, self.tokenizer = load(HF_MODEL_PATH)
self.model_key = (HF_MODEL_PATH, None)

def load(self, model, adapter=None):
assert model in ["default_model", "chat_model"]