MLX model support #300

Merged · Feb 12, 2025 · 22 commits
16 changes: 15 additions & 1 deletion docs/source/en/guided_tour.md
@@ -28,10 +28,11 @@ To initialize a minimal agent, you need at least these two arguments:
- [`HfApiModel`] leverages a `huggingface_hub.InferenceClient` under the hood and supports all Inference Providers on the Hub.
- [`LiteLLMModel`] similarly lets you call 100+ different models and providers through [LiteLLM](https://docs.litellm.ai/)!
- [`AzureOpenAIServerModel`] allows you to use OpenAI models deployed in [Azure](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
- [`MLXModel`] creates an [mlx-lm](https://pypi.org/project/mlx-lm/) pipeline to run inference on your local machine.

- `tools`, a list of `Tools` that the agent can use to solve the task. It can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.

Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service), or [mlx-lm](https://pypi.org/project/mlx-lm/).

<hfoptions id="Pick a LLM">
<hfoption id="HF Inference API">
@@ -148,6 +149,19 @@ agent.run(
)
```

</hfoption>
<hfoption id="mlx-lm">

```python
# !pip install smolagents[mlx-lm]
from smolagents import CodeAgent, MLXModel

mlx_model = MLXModel("mlx-community/Qwen2.5-Coder-32B-Instruct-4bit")
agent = CodeAgent(model=mlx_model, tools=[], add_base_tools=True)

agent.run("Could you give me the 118th number in the Fibonacci sequence?")
```

</hfoption>
</hfoptions>

Expand Down
22 changes: 21 additions & 1 deletion docs/source/en/reference/models.md
@@ -146,4 +146,24 @@ model = AzureOpenAIServerModel(
)
```

[[autodoc]] AzureOpenAIServerModel

### MLXModel

For convenience, we have added a `MLXModel` that implements the points above by building a local `mlx-lm` pipeline for the `model_id` given at initialization.

```python
from smolagents import MLXModel

model = MLXModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct")

print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
```
```text
>>> What a
```
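Here the stop sequence does the truncation: the model presumably began replying "What a great ...", and `stop_sequences=["great"]` cuts the output just before `great` completes.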

> [!TIP]
> You must have `mlx-lm` installed on your machine. Run `pip install smolagents[mlx-lm]` if it isn't installed yet.

[[autodoc]] MLXModel
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -51,6 +51,9 @@ mcp = [
"mcpadapt>=0.0.6",
"mcp",
]
mlx-lm = [
"mlx-lm"
]
openai = [
"openai>=1.58.1"
]
124 changes: 124 additions & 0 deletions src/smolagents/models.py
@@ -18,6 +18,7 @@
import logging
import os
import random
import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass
from enum import Enum
@@ -410,6 +411,128 @@ def __call__(
        return message


class MLXModel(Model):
    """A class to interact with models loaded using MLX on Apple silicon.

    > [!TIP]
    > You must have `mlx-lm` installed on your machine. Run `pip install smolagents[mlx-lm]` if it isn't installed yet.

    Parameters:
        model_id (str):
            The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
> **Review thread (resolved)** on `tool_name_key`:
> **aymeric-roucher (Jan 30, 2025):** We don't need this tool_name_key in TransformersModel, what is the reason for needing it in MLXModel?
> **aymeric-roucher:** Actually upon inspection it seems like a good idea, let's keep it and we might set the same in TransformersModel later on.
> **g-eoj (author):** I found the params to be required unless I was using a logits processor and regex to force the key names in the output. Which might be a better solution overall, but I have no strong opinions yet.

        tool_name_key (str):
            The key, which can usually be found in the model's chat template, for retrieving a tool name.
        tool_arguments_key (str):
            The key, which can usually be found in the model's chat template, for retrieving tool arguments.
        trust_remote_code (bool):
            Some models on the Hub require running remote code: for these models, you have to set this flag to True.
        kwargs (dict, *optional*):
            Any additional keyword arguments that you want to pass to `model.generate()`, for instance `max_tokens`.

    Example:
    ```python
    >>> engine = MLXModel(
    ...     model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
    ...     max_tokens=10000,
    ... )
    >>> messages = [
    ...     {
    ...         "role": "user",
    ...         "content": [
    ...             {"type": "text", "text": "Explain quantum mechanics in simple terms."}
    ...         ]
    ...     }
    ... ]
    >>> response = engine(messages, stop_sequences=["END"])
    >>> print(response)
    "Quantum mechanics is the branch of physics that studies..."
    ```
    """

    def __init__(
        self,
        model_id: str,
        tool_name_key: str = "name",
        tool_arguments_key: str = "arguments",
        trust_remote_code: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if not _is_package_available("mlx_lm"):
            raise ModuleNotFoundError(
                "Please install 'mlx-lm' extra to use 'MLXModel': `pip install 'smolagents[mlx-lm]'`"
            )
        import mlx_lm

        self.model_id = model_id
        self.model, self.tokenizer = mlx_lm.load(model_id, tokenizer_config={"trust_remote_code": trust_remote_code})
        self.stream_generate = mlx_lm.stream_generate
        self.tool_name_key = tool_name_key
        self.tool_arguments_key = tool_arguments_key

    def _to_message(self, text, tools_to_call_from):
        if tools_to_call_from:
            # tmp solution for extracting tool JSON without assuming a specific model output format:
            # slice from the first "{" to the last "}" in the generated text and parse it
            maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
            parsed_text = json.loads(maybe_json)
            tool_name = parsed_text.get(self.tool_name_key, None)
            tool_arguments = parsed_text.get(self.tool_arguments_key, None)
            if tool_name:
                return ChatMessage(
                    role="assistant",
                    content="",
                    tool_calls=[
                        ChatMessageToolCall(
                            id=str(uuid.uuid4()),
                            type="function",
                            function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
                        )
                    ],
                )
        return ChatMessage(role="assistant", content=text)

    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        tools_to_call_from: Optional[List[Tool]] = None,
        **kwargs,
    ) -> ChatMessage:
        completion_kwargs = self._prepare_completion_kwargs(
            flatten_messages_as_text=True,  # mlx-lm doesn't support vision models
            messages=messages,
            stop_sequences=stop_sequences,
            grammar=grammar,
            tools_to_call_from=tools_to_call_from,
            **kwargs,
        )
        messages = completion_kwargs.pop("messages")
        prepared_stop_sequences = completion_kwargs.pop("stop", [])
        tools = completion_kwargs.pop("tools", None)
        completion_kwargs.pop("tool_choice", None)  # discarded: not used by the mlx-lm pipeline

        prompt_ids = self.tokenizer.apply_chat_template(
            messages,
            tools=tools,
            add_generation_prompt=True,
        )

        self.last_input_token_count = len(prompt_ids)
        self.last_output_token_count = 0
        text = ""

        for chunk in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs):
            self.last_output_token_count += 1
            text += chunk.text
            for stop_sequence in prepared_stop_sequences:
                if text.strip().endswith(stop_sequence):
                    text = text[: -len(stop_sequence)]
                    return self._to_message(text, tools_to_call_from)

        return self._to_message(text, tools_to_call_from)
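
As an aside, here is a minimal standalone sketch of the JSON-extraction trick `_to_message` relies on: it slices from the first `{` to the last `}` of the generated text and parses the result. The sample output and tool-call keys below are hypothetical:

```python
import json

# Hypothetical model output containing a JSON tool call
text = 'I will call a tool now: {"name": "search", "arguments": {"query": "mlx"}} Done.'

# Everything after the first "{" ...
after_first = text.split("{", 1)[-1]
# ... then drop everything after the last "}" (reverse, split once, reverse back)
maybe_json = "{" + after_first[::-1].split("}", 1)[-1][::-1] + "}"

print(json.loads(maybe_json))  # {'name': 'search', 'arguments': {'query': 'mlx'}}
```

If a model's chat template emits different keys, the defaults can be overridden at initialization, e.g. `MLXModel(model_id=..., tool_name_key="tool", tool_arguments_key="params")` (hypothetical key names).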


class TransformersModel(Model):
"""A class that uses Hugging Face's Transformers library for language model interaction.

@@ -832,6 +955,7 @@ def __init__(
    "tool_role_conversions",
    "get_clean_message_list",
    "Model",
    "MLXModel",
    "TransformersModel",
    "HfApiModel",
    "LiteLLMModel",
10 changes: 9 additions & 1 deletion tests/test_models.py
@@ -14,6 +14,7 @@
# limitations under the License.
import json
import os
import sys
import unittest
from pathlib import Path
from typing import Optional
@@ -22,7 +23,7 @@
import pytest
from transformers.testing_utils import get_tests_dir

from smolagents import ChatMessage, HfApiModel, MLXModel, TransformersModel, models, tool
from smolagents.models import get_clean_message_list, parse_json_if_needed


@@ -60,6 +61,13 @@ def test_get_hfapi_message_no_tool_external_provider(self):
        messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
        model(messages, stop_sequences=["great"])

    @unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
    def test_get_mlx_message_no_tool(self):
        model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
        messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
        output = model(messages, stop_sequences=["great"]).content
        assert output.startswith("Hello")

    def test_transformers_message_no_tool(self):
        model = TransformersModel(
            model_id="HuggingFaceTB/SmolLM2-135M-Instruct",