Skip to content

Commit

Permalink
MLX model support (#300)
Browse files Browse the repository at this point in the history
  • Loading branch information
g-eoj authored Feb 12, 2025
1 parent bca3a9b commit 9b96199
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 3 deletions.
16 changes: 15 additions & 1 deletion docs/source/en/guided_tour.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ To initialize a minimal agent, you need at least these two arguments:
- [`HfApiModel`] leverages a `huggingface_hub.InferenceClient` under the hood and supports all Inference Providers on the Hub.
- [`LiteLLMModel`] similarly lets you call 100+ different models and providers through [LiteLLM](https://docs.litellm.ai/)!
- [`AzureOpenAIServerModel`] allows you to use OpenAI models deployed in [Azure](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
- [`MLXModel`] creates a [mlx-lm](https://pypi.org/project/mlx-lm/) pipeline to run inference on your local machine.

- `tools`, a list of `Tools` that the agent can use to solve the task. It can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.

Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), or [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service), or [mlx-lm](https://pypi.org/project/mlx-lm/).

<hfoptions id="Pick a LLM">
<hfoption id="HF Inference API">
Expand Down Expand Up @@ -148,6 +149,19 @@ agent.run(
)
```

</hfoption>
<hfoption id="mlx-lm">

```python
# !pip install smolagents[mlx-lm]
from smolagents import CodeAgent, MLXModel

mlx_model = MLXModel("mlx-community/Qwen2.5-Coder-32B-Instruct-4bit")
agent = CodeAgent(model=mlx_model, tools=[], add_base_tools=True)

agent.run("Could you give me the 118th number in the Fibonacci sequence?")
```

</hfoption>
</hfoptions>

Expand Down
21 changes: 20 additions & 1 deletion docs/source/en/reference/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,23 @@ model = AzureOpenAIServerModel(
)
```

[[autodoc]] AzureOpenAIServerModel
[[autodoc]] AzureOpenAIServerModel

### MLXModel


```python
from smolagents import MLXModel

model = MLXModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct")

print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
```
```text
>>> What a
```

> [!TIP]
> You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it's not the case.
[[autodoc]] MLXModel
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ mcp = [
"mcpadapt>=0.0.6",
"mcp",
]
mlx-lm = [
"mlx-lm"
]
openai = [
"openai>=1.58.1"
]
Expand Down
124 changes: 124 additions & 0 deletions src/smolagents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import logging
import os
import random
import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass
from enum import Enum
Expand Down Expand Up @@ -415,6 +416,128 @@ def __call__(
return message


class MLXModel(Model):
"""A class to interact with models loaded using MLX on Apple silicon.
> [!TIP]
> You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it's not the case.
Parameters:
model_id (str):
The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
tool_name_key (str):
The key, which can usually be found in the model's chat template, for retrieving a tool name.
tool_arguments_key (str):
The key, which can usually be found in the model's chat template, for retrieving tool arguments.
trust_remote_code (bool):
Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
kwargs (dict, *optional*):
Any additional keyword arguments that you want to use in model.generate(), for instance `max_tokens`.
Example:
```python
>>> engine = MLXModel(
... model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
... max_tokens=10000,
... )
>>> messages = [
... {
... "role": "user",
... "content": [
... {"type": "text", "text": "Explain quantum mechanics in simple terms."}
... ]
... }
... ]
>>> response = engine(messages, stop_sequences=["END"])
>>> print(response)
"Quantum mechanics is the branch of physics that studies..."
```
"""

def __init__(
self,
model_id: str,
tool_name_key: str = "name",
tool_arguments_key: str = "arguments",
trust_remote_code: bool = False,
**kwargs,
):
super().__init__(**kwargs)
if not _is_package_available("mlx_lm"):
raise ModuleNotFoundError(
"Please install 'mlx-lm' extra to use 'MLXModel': `pip install 'smolagents[mlx-lm]'`"
)
import mlx_lm

self.model_id = model_id
self.model, self.tokenizer = mlx_lm.load(model_id, tokenizer_config={"trust_remote_code": trust_remote_code})
self.stream_generate = mlx_lm.stream_generate
self.tool_name_key = tool_name_key
self.tool_arguments_key = tool_arguments_key

def _to_message(self, text, tools_to_call_from):
if tools_to_call_from:
# tmp solution for extracting tool JSON without assuming a specific model output format
maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
parsed_text = json.loads(maybe_json)
tool_name = parsed_text.get(self.tool_name_key, None)
tool_arguments = parsed_text.get(self.tool_arguments_key, None)
if tool_name:
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatMessageToolCall(
id=uuid.uuid4(),
type="function",
function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
)
],
)
return ChatMessage(role="assistant", content=text)

def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
tools_to_call_from: Optional[List[Tool]] = None,
**kwargs,
) -> ChatMessage:
completion_kwargs = self._prepare_completion_kwargs(
flatten_messages_as_text=True, # mlx-lm doesn't support vision models
messages=messages,
stop_sequences=stop_sequences,
grammar=grammar,
tools_to_call_from=tools_to_call_from,
**kwargs,
)
messages = completion_kwargs.pop("messages")
prepared_stop_sequences = completion_kwargs.pop("stop", [])
tools = completion_kwargs.pop("tools", None)
completion_kwargs.pop("tool_choice", None)

prompt_ids = self.tokenizer.apply_chat_template(
messages,
tools=tools,
add_generation_prompt=True,
)

self.last_input_token_count = len(prompt_ids)
self.last_output_token_count = 0
text = ""

for _ in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs):
self.last_output_token_count += 1
text += _.text
for stop_sequence in prepared_stop_sequences:
if text.strip().endswith(stop_sequence):
text = text[: -len(stop_sequence)]
return self._to_message(text, tools_to_call_from)

return self._to_message(text, tools_to_call_from)


class TransformersModel(Model):
"""A class that uses Hugging Face's Transformers library for language model interaction.
Expand Down Expand Up @@ -837,6 +960,7 @@ def __init__(
"tool_role_conversions",
"get_clean_message_list",
"Model",
"MLXModel",
"TransformersModel",
"HfApiModel",
"LiteLLMModel",
Expand Down
10 changes: 9 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
import json
import os
import sys
import unittest
from pathlib import Path
from typing import Optional
Expand All @@ -22,7 +23,7 @@
import pytest
from transformers.testing_utils import get_tests_dir

from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
from smolagents import ChatMessage, HfApiModel, MLXModel, TransformersModel, models, tool
from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed


Expand Down Expand Up @@ -61,6 +62,13 @@ def test_get_hfapi_message_no_tool_external_provider(self):
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
model(messages, stop_sequences=["great"])

@unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
def test_get_mlx_message_no_tool(self):
model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
output = model(messages, stop_sequences=["great"]).content
assert output.startswith("Hello")

def test_transformers_message_no_tool(self):
model = TransformersModel(
model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
Expand Down

0 comments on commit 9b96199

Please sign in to comment.