
MLX model support #300

Merged: 22 commits into huggingface:main on Feb 12, 2025
Conversation

g-eoj (Contributor) commented Jan 21, 2025

The goal of this PR is to enable users to run smolagents with models loaded onto Apple silicon with mlx-lm. The mlx-community has made many models available for experimentation. Personally, I find running locally to be a convenient way to learn and experiment with the smolagents library, so I made this PR as a possible new feature.

Example usage:

from smolagents.models import MLXModel

mlx_model = MLXModel("mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", max_tokens=10000)
messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
print(mlx_model(messages))
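
For context, here is a short sketch of how the model could plug into an agent, assuming MLXModel lands in smolagents.models as proposed (the task string is just an illustrative example):

from smolagents import CodeAgent
from smolagents.models import MLXModel

# Hypothetical agent setup using the MLXModel class added in this PR.
mlx_model = MLXModel("mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", max_tokens=10000)
agent = CodeAgent(tools=[], model=mlx_model, add_base_tools=True)
agent.run("What is the 10th Fibonacci number?")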

Some questions:

  • tests won't run in CI/CD due to hardware requirements; what is the preferred way to handle that?
  • is anything needed for docs, and if so, where should it go?

g-eoj force-pushed the g-eoj/mlx-model-support branch from b193f27 to 662a481 on January 22, 2025 16:50
g-eoj (Contributor, PR author) commented Jan 26, 2025

@kingdomad @clefourrier as you are reviewing #337, can you please take a look at this PR?

aymeric-roucher (Collaborator) commented

Sorry for the late review @g-eoj, I just left some comments! 😃

g-eoj (Contributor, PR author) commented Feb 3, 2025

Hi @aymeric-roucher, can you please take a look? I don't see any comments or a review from you.

I can add docs when/if the design is finalized.

sysradium (Contributor) commented Feb 4, 2025

Works for me. Though a slight correction to the example is:

mlx_model = MLXModel(
    "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
    max_tokens=10000,
)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Explain quantum mechanics in simple terms."}
        ],
    }
]
print(mlx_model(messages))

I guess it will be a good addition. You can still use LiteLLMModel with LMStudio to achieve a similar goal, but not needing to run LMStudio seems quite convenient.

Two review threads on src/smolagents/models.py (outdated, resolved).
g-eoj requested a review from sysradium on February 4, 2025 20:27
sysradium (Contributor) left a comment

To me it looks good.

sysradium (Contributor) commented

@g-eoj is there any downside to checking for a stop_sequence just in the most recently received chunk? If not, the performance could be slightly improved.

For example, if you take these options:

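# Accumulate into a string and check the full accumulated text against each stop sequence.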
def implementation_1(text, stop_sequences):
    text_accumulated = ""
    for chunk in text:
        text_accumulated += chunk
        for stop_sequence in stop_sequences:
            if text_accumulated.strip().endswith(stop_sequence):
                text_accumulated = text_accumulated[: -len(stop_sequence)]
                return text_accumulated  # Simulating _to_message call
    return text_accumulated


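# Accumulate into a string but check only the latest chunk against each stop sequence.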
def implementation_2(text, stop_sequences):
    text_accumulated = ""
    for chunk in text:
        text_accumulated += chunk
        for stop_sequence in stop_sequences:
            if chunk.endswith(stop_sequence):
                text_accumulated = text_accumulated[: -len(stop_sequence)]
                return text_accumulated  # Simulating _to_message call
    return text_accumulated


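# Accumulate into a list, check only the latest chunk, and join once at the end.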
def implementation_3(text, stop_sequences):
    text_accumulated = []
    for chunk in text:
        text_accumulated.append(chunk)
        for stop_sequence in stop_sequences:
            if chunk.endswith(stop_sequence):
                joined_text = "".join(text_accumulated)[: -len(stop_sequence)]
                return joined_text
    return "".join(text_accumulated)

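# Like implementation_3, but short-circuit with a single tuple endswith() before the inner loop.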
def implementation_5(text, stop_sequences):
    text_accumulated = []
    for chunk in text:
        text_accumulated.append(chunk)
        if not chunk.endswith(tuple(stop_sequences)):
            continue

        for stop_sequence in stop_sequences:
            if chunk.endswith(stop_sequence):
                joined_text = "".join(text_accumulated)[: -len(stop_sequence)]
                return joined_text

    return "".join(text_accumulated)


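# Like implementation_5, but pick the matched stop sequence with next() instead of a second loop.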
def implementation_6(text, stop_sequences):
    text_accumulated = []
    for chunk in text:
        text_accumulated.append(chunk)
        if not chunk.endswith(tuple(stop_sequences)):
            continue

        matched_suffix = next(s for s in stop_sequences if chunk.endswith(s))
        return "".join(text_accumulated)[: -len(matched_suffix)]

    return "".join(text_accumulated)

The benchmark results were attached as an image (timing plot not reproduced here).
Might be useful in intensive apps.
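
For reference, a minimal sketch of how such a micro-benchmark could be run with timeit; the chunk stream and repeat count below are made up for illustration, and implementation_1 through implementation_6 refer to the functions defined above:

import timeit

# Hypothetical input: many small chunks followed by a multi-character stop sequence.
chunks = ["chunk "] * 5_000 + ["<end_code>"]
stop_sequences = ["<end_code>", "Observation:"]

for impl in (implementation_1, implementation_2, implementation_3,
             implementation_5, implementation_6):
    elapsed = timeit.timeit(lambda impl=impl: impl(chunks, stop_sequences), number=3)
    print(f"{impl.__name__}: {elapsed:.4f}s")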

g-eoj (Contributor, PR author) commented Feb 9, 2025

@sysradium what is a chunk in this case? I'm pretty sure mlx-lm streams always produce one token at a time - could it still be used to get chunks efficiently? It also seems like you'd need a guarantee that the stop string doesn't get split across chunks.

I'm all for making this more efficient.

sysradium (Contributor) commented Feb 9, 2025

@g-eoj the chunk is your _.text. So it is whatever mlx-lm decided to yield :) I had assumed that a stop sequence is always a single token (or a single unit of what mlx-lm yields).

They themselves implement a generate function which does exactly what you did (https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py#L594), i.e. just appending to the text variable. But that is quite inefficient since strings are immutable in Python.

Not sure if they did it like this because they know something, or just don't care about performance. I thought maybe you had checked that while working on the implementation.

g-eoj (Contributor, PR author) commented Feb 9, 2025

I had an assumption that a stop sequence is always a single token (or a single unit of what mlx-lm yields)

I think the smolagents stop sequences, for example

stop_sequences=["<end_code>", "Observation:"],

have the potential to be composed of multiple tokens, depending on the tokenizer.

To your point about strings being immutable - I assumed https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py#L594 was okay since Apple is doing it. I didn't check whether there is a good reason for it (in Apple's case it seems to just be there to support verbose output). I can't think of any reason for appending to the text variable other than needing to check multi-token stop sequences.
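
A tiny illustration of that concern (the chunk boundaries below are invented; the real split depends on the tokenizer): if "<end_code>" is yielded in pieces, a check against only the latest chunk never fires, while a check against the accumulated text does.

# Hypothetical chunking of a multi-token stop sequence.
chunks = ["final answer", "<end", "_code", ">"]
stop_sequences = ["<end_code>", "Observation:"]

accumulated = ""
for chunk in chunks:
    accumulated += chunk
    per_chunk_hit = any(chunk.endswith(s) for s in stop_sequences)
    accumulated_hit = any(accumulated.endswith(s) for s in stop_sequences)
    print(repr(chunk), per_chunk_hit, accumulated_hit)
# The per-chunk check stays False on every chunk; the accumulated check fires on the last one.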

g-eoj (Contributor, PR author) commented Feb 10, 2025

For reference, I tried this approach to avoid appending to the text variable. I didn't find evidence of a speedup.

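# Assumes model and tokenizer are already loaded (e.g. via mlx_lm.load) and mlx_lm is imported.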
def check_stop_1(messages, stop_sequences):
    prompt_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
    )
    outputs = []
    for step in mlx_lm.stream_generate(model, tokenizer, prompt=prompt_ids, max_tokens=10000):
        outputs.append(step)
        for stop_sequence in stop_sequences:
            # assumes stop sequence will never be more than 10 tokens
            recent_text = "".join([_.text for _ in outputs[-10:]]) 
            if recent_text.rstrip().endswith(stop_sequence):
                text = "".join([_.text for _ in outputs])
                text = text.rstrip()[:-len(stop_sequence)]
                return text         
    text = "".join([_.text for _ in outputs])
    return text
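
A possible refinement of the same idea, sketched here for discussion (not benchmarked): keep the full output in a list, but track only a short rolling suffix for the stop check so no repeated joins are needed. The slack parameter is an assumed upper bound on whitespace trailing a stop sequence, and the function works on any iterable of text chunks, e.g. (step.text for step in mlx_lm.stream_generate(...)).

def check_stop_rolling(chunks, stop_sequences, slack=16):
    # Only the last `window` characters of the output can matter for the stop check.
    window = max(len(s) for s in stop_sequences) + slack
    outputs = []
    tail = ""
    for chunk in chunks:
        outputs.append(chunk)
        tail = (tail + chunk)[-window:]
        for stop_sequence in stop_sequences:
            if tail.rstrip().endswith(stop_sequence):
                return "".join(outputs).rstrip()[: -len(stop_sequence)]
    return "".join(outputs)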

g-eoj (Contributor, PR author) commented Feb 10, 2025

I'm going to move this work over to https://github.com/g-eoj/mac-smolagents. Still happy to contribute, but I'm skeptical this PR makes it in.

sysradium (Contributor) commented

Which is a pity. I use it locally quite often now :/

g-eoj (Contributor, PR author) commented Feb 10, 2025

Just so there is no confusion, you can still use it alongside smolagents. You'll just need to install an extra whl and change your code a bit. If for some reason this doesn't work, please make an issue and I'll try to fix it.

import mac_smolagents
import smolagents


mlx_language_model = mac_smolagents.MLXLModel(
    model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit"
)
agent = smolagents.CodeAgent(
    model=mlx_language_model, tools=[], add_base_tools=True
)
agent.run(...)

HuggingFaceDocBuilderDev (bot) commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Review thread on the MLXModel docstring:

    Parameters:
        model_id (str):
            The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
        tool_name_key (str):

aymeric-roucher (Collaborator) commented Jan 30, 2025

We don't need this tool_name_key in TransformersModel, what is the reason for needing it in MLXModel?

aymeric-roucher (Collaborator) replied

Actually upon inspection it seems like a good idea, let's keep it and we might set the same in TransformersModel later on.

g-eoj (Contributor, PR author) replied

I found the params to be required unless I was using a logits processor and regex to force the key names in the output. That might be a better solution overall, but I have no strong opinions yet.

aymeric-roucher (Collaborator) commented Feb 12, 2025

@g-eoj sorry for the late review, don't hesitate to ping again when this happens!
Also, please run ruff format and ruff check on your PR to fix the formatting and pass the tests!

aymeric-roucher merged commit 9b96199 into huggingface:main on Feb 12, 2025 (3 of 4 checks passed)
aymeric-roucher (Collaborator) commented

And thank you for the contribution, great work!
