diff --git a/docs/source/en/guided_tour.md b/docs/source/en/guided_tour.md
index ffdac9b47..dd8d5f287 100644
--- a/docs/source/en/guided_tour.md
+++ b/docs/source/en/guided_tour.md
@@ -28,10 +28,11 @@ To initialize a minimal agent, you need at least these two arguments:
     - [`HfApiModel`] leverages a `huggingface_hub.InferenceClient` under the hood and supports all Inference Providers on the Hub.
     - [`LiteLLMModel`] similarly lets you call 100+ different models and providers through [LiteLLM](https://docs.litellm.ai/)!
     - [`AzureOpenAIServerModel`] allows you to use OpenAI models deployed in [Azure](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
+    - [`MLXModel`] creates an [mlx-lm](https://pypi.org/project/mlx-lm/) pipeline to run inference on your local machine.
 
 - `tools`, a list of `Tools` that the agent can use to solve the task. It can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.
 
-Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), or [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
+Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service), or [mlx-lm](https://pypi.org/project/mlx-lm/).
@@ -148,6 +149,19 @@ agent.run(
 )
 ```
 
+</hfoption>
+<hfoption id="mlx-lm">
+
+```python
+# !pip install smolagents[mlx-lm]
+from smolagents import CodeAgent, MLXModel
+
+mlx_model = MLXModel("mlx-community/Qwen2.5-Coder-32B-Instruct-4bit")
+agent = CodeAgent(model=mlx_model, tools=[], add_base_tools=True)
+
+agent.run("Could you give me the 118th number in the Fibonacci sequence?")
+```
+
diff --git a/docs/source/en/reference/models.md b/docs/source/en/reference/models.md
index 559b0b586..2a7f8f45d 100644
--- a/docs/source/en/reference/models.md
+++ b/docs/source/en/reference/models.md
@@ -147,4 +147,23 @@ model = AzureOpenAIServerModel(
 )
 ```
 
-[[autodoc]] AzureOpenAIServerModel
\ No newline at end of file
+[[autodoc]] AzureOpenAIServerModel
+
+### MLXModel
+
+
+```python
+from smolagents import MLXModel
+
+model = MLXModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct")
+
+print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
+```
+```text
+>>> What a
+```
+
+> [!TIP]
+> You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it's not already installed.
+
+[[autodoc]] MLXModel
diff --git a/pyproject.toml b/pyproject.toml
index a395752f6..2a333699b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,9 @@ mcp = [
     "mcpadapt>=0.0.6",
     "mcp",
 ]
+mlx-lm = [
+    "mlx-lm"
+]
 openai = [
     "openai>=1.58.1"
 ]
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index 28de2eefa..1d55a2e8d 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -18,6 +18,7 @@
 import logging
 import os
 import random
+import uuid
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from enum import Enum
@@ -415,6 +416,128 @@ def __call__(
         return message
 
+
+class MLXModel(Model):
+    """A class to interact with models loaded using MLX on Apple silicon.
+
+    > [!TIP]
+    > You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it's not already installed.
+
+    Parameters:
+        model_id (str):
+            The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
+        tool_name_key (str):
+            The key, which can usually be found in the model's chat template, for retrieving a tool name.
+        tool_arguments_key (str):
+            The key, which can usually be found in the model's chat template, for retrieving tool arguments.
+        trust_remote_code (bool):
+            Some models on the Hub require running remote code: for such models, you must set this flag to True.
+        kwargs (dict, *optional*):
+            Any additional keyword arguments forwarded to the underlying `mlx_lm.stream_generate` call, for instance `max_tokens`.
+
+    Example:
+    ```python
+    >>> engine = MLXModel(
+    ...     model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
+    ...     max_tokens=10000,
+    ... )
+    >>> messages = [
+    ...     {
+    ...         "role": "user",
+    ...         "content": [
+    ...             {"type": "text", "text": "Explain quantum mechanics in simple terms."}
+    ...         ]
+    ...     }
+    ... ]
+    >>> response = engine(messages, stop_sequences=["END"])
+    >>> print(response)
+    "Quantum mechanics is the branch of physics that studies..."
+    ```
+    """
+
+    def __init__(
+        self,
+        model_id: str,
+        tool_name_key: str = "name",
+        tool_arguments_key: str = "arguments",
+        trust_remote_code: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if not _is_package_available("mlx_lm"):
+            raise ModuleNotFoundError(
+                "Please install 'mlx-lm' extra to use 'MLXModel': `pip install 'smolagents[mlx-lm]'`"
+            )
+        import mlx_lm
+
+        self.model_id = model_id
+        self.model, self.tokenizer = mlx_lm.load(model_id, tokenizer_config={"trust_remote_code": trust_remote_code})
+        self.stream_generate = mlx_lm.stream_generate
+        self.tool_name_key = tool_name_key
+        self.tool_arguments_key = tool_arguments_key
+
+    def _to_message(self, text, tools_to_call_from):
+        if tools_to_call_from:
+            # tmp solution for extracting tool JSON without assuming a specific model output format: keep the substring from the first "{" to the last "}"
+            maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
+            parsed_text = json.loads(maybe_json)
+            tool_name = parsed_text.get(self.tool_name_key, None)
+            tool_arguments = parsed_text.get(self.tool_arguments_key, None)
+            if tool_name:
+                return ChatMessage(
+                    role="assistant",
+                    content="",
+                    tool_calls=[
+                        ChatMessageToolCall(
+                            id=str(uuid.uuid4()),  # ChatMessageToolCall.id is a str, so stringify the UUID
+                            type="function",
+                            function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
+                        )
+                    ],
+                )
+        return ChatMessage(role="assistant", content=text)
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
+    ) -> ChatMessage:
+        completion_kwargs = self._prepare_completion_kwargs(
+            flatten_messages_as_text=True,  # mlx-lm doesn't support vision models
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+        messages = completion_kwargs.pop("messages")
+        prepared_stop_sequences = completion_kwargs.pop("stop", [])
+        tools = completion_kwargs.pop("tools", None)
+        completion_kwargs.pop("tool_choice", None)
+
+        prompt_ids = self.tokenizer.apply_chat_template(
+            messages,
+            tools=tools,
+            add_generation_prompt=True,
+        )
+
+        self.last_input_token_count = len(prompt_ids)
+        self.last_output_token_count = 0
+        text = ""
+
+        for chunk in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs):
+            self.last_output_token_count += 1
+            text += chunk.text
+            for stop_sequence in prepared_stop_sequences:
+                if text.strip().endswith(stop_sequence):
+                    text = text[: -len(stop_sequence)]
+                    return self._to_message(text, tools_to_call_from)
+
+        return self._to_message(text, tools_to_call_from)
+
+
 class TransformersModel(Model):
     """A class that uses Hugging Face's Transformers library for language model interaction.
@@ -837,6 +960,7 @@ def __init__(
     "tool_role_conversions",
     "get_clean_message_list",
     "Model",
+    "MLXModel",
     "TransformersModel",
     "HfApiModel",
     "LiteLLMModel",
diff --git a/tests/test_models.py b/tests/test_models.py
index 02105af11..0e7a42cdd 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import json
 import os
+import sys
 import unittest
 from pathlib import Path
 from typing import Optional
@@ -22,7 +23,7 @@
 import pytest
 from transformers.testing_utils import get_tests_dir
 
-from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
+from smolagents import ChatMessage, HfApiModel, MLXModel, TransformersModel, models, tool
 from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
 
@@ -61,6 +62,13 @@ def test_get_hfapi_message_no_tool_external_provider(self):
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
         model(messages, stop_sequences=["great"])
 
+    @unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
+    def test_get_mlx_message_no_tool(self):
+        model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
+        messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
+        output = model(messages, stop_sequences=["great"]).content
+        assert output.startswith("Hello")
+
     def test_transformers_message_no_tool(self):
         model = TransformersModel(
             model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
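
---

Note (not part of the patch): the sketch below shows how the new `MLXModel` would exercise its `_to_message` tool-call path end to end through a `ToolCallingAgent`. The model ID and the `get_weather` tool are illustrative assumptions, and it requires an Apple-silicon machine with `smolagents[mlx-lm]` installed plus a local model that emits its tool call as JSON under the default `"name"`/`"arguments"` keys.

```python
# Hypothetical usage sketch, not part of this diff. Assumes Apple silicon,
# `pip install "smolagents[mlx-lm]"`, and a tool-calling-capable local model.
from smolagents import MLXModel, ToolCallingAgent, tool


@tool
def get_weather(city: str) -> str:
    """Returns a canned weather report for a city.

    Args:
        city: Name of the city to look up.
    """
    return f"It is sunny in {city}."


# tool_name_key/tool_arguments_key are left at their defaults ("name"/"arguments"),
# matching the keys `_to_message` reads from the JSON the model generates.
model = MLXModel(
    model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
    max_tokens=2048,
)
agent = ToolCallingAgent(tools=[get_weather], model=model)
print(agent.run("What is the weather in Paris?"))
```

Stop sequences behave as in the test above: `stop_sequences` is popped into `prepared_stop_sequences` before `stream_generate` is called, and the loop in `__call__` checks the accumulated text after every token, truncating the match before returning.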