MLX model support #300
Conversation
Force-pushed from b193f27 to 662a481
@kingdomad @clefourrier as you are reviewing #337, can you please take a look at this PR?
Sorry for the late review @g-eoj, I just left some comments! 😃
Hi @aymeric-roucher, can you please take a look? There are no comments or review from you yet. I can add docs when/if the design is finalized.
Works for me. Though a slight correction to the example is:

```python
mlx_model = MLXModel(
    "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
    max_tokens=10000,
)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Explain quantum mechanics in simple terms."}
        ],
    }
]
print(mlx_model(messages))
```

I guess this will be a good addition. You can still use `LiteLLMModel` with LM Studio to achieve a similar goal; however, not needing to run LM Studio seems quite convenient.
To me it looks good.
@g-eoj is there any downside to checking for a stop sequence in the latest chunk only, instead of in the accumulated text? For example if you take those options:

```python
def implementation_1(text, stop_sequences):
    # Check the full accumulated text (stripped) after every chunk.
    text_accumulated = ""
    for chunk in text:
        text_accumulated += chunk
        for stop_sequence in stop_sequences:
            if text_accumulated.strip().endswith(stop_sequence):
                text_accumulated = text_accumulated[: -len(stop_sequence)]
                return text_accumulated  # Simulating _to_message call
    return text_accumulated


def implementation_2(text, stop_sequences):
    # Check only the latest chunk instead of the accumulated text.
    text_accumulated = ""
    for chunk in text:
        text_accumulated += chunk
        for stop_sequence in stop_sequences:
            if chunk.endswith(stop_sequence):
                text_accumulated = text_accumulated[: -len(stop_sequence)]
                return text_accumulated  # Simulating _to_message call
    return text_accumulated


def implementation_3(text, stop_sequences):
    # Accumulate chunks in a list and join once at the end.
    text_accumulated = []
    for chunk in text:
        text_accumulated.append(chunk)
        for stop_sequence in stop_sequences:
            if chunk.endswith(stop_sequence):
                joined_text = "".join(text_accumulated)[: -len(stop_sequence)]
                return joined_text
    return "".join(text_accumulated)


def implementation_5(text, stop_sequences):
    # Like implementation_3, but skip the inner loop entirely
    # when the chunk matches no stop sequence.
    text_accumulated = []
    for chunk in text:
        text_accumulated.append(chunk)
        if not chunk.endswith(tuple(stop_sequences)):
            continue
        for stop_sequence in stop_sequences:
            if chunk.endswith(stop_sequence):
                joined_text = "".join(text_accumulated)[: -len(stop_sequence)]
                return joined_text
    return "".join(text_accumulated)


def implementation_6(text, stop_sequences):
    # Like implementation_5, but find the matched suffix with next().
    text_accumulated = []
    for chunk in text:
        text_accumulated.append(chunk)
        if not chunk.endswith(tuple(stop_sequences)):
            continue
        matched_suffix = next(s for s in stop_sequences if chunk.endswith(s))
        return "".join(text_accumulated)[: -len(matched_suffix)]
    return "".join(text_accumulated)
```

The benchmark would result in:
@sysradium what is a `chunk` in your benchmark? I'm all for making this more efficient.
@g-eoj the chunk is your streamed piece of text. They themselves implement a generate function which does exactly what you did (https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py#L594), i.e. just appending to the text variable. But it is quite inefficient since strings are immutable in Python. Not sure if they did it like this because they know something, or don't care about performance. I thought maybe you checked that when you were working with the implementation.
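As a quick sanity check on the immutability point (a sketch, not from the thread; sizes are arbitrary): CPython can often extend a string in place when its reference count is 1, so the list-append-then-join approach is not guaranteed to win, and measuring is the only way to know.

```python
import timeit

chunks = ["token"] * 100_000

def accumulate_str(pieces):
    # Repeated str concatenation; quadratic in the worst case,
    # though CPython can often extend the string in place.
    text = ""
    for piece in pieces:
        text += piece
    return text

def accumulate_list(pieces):
    # Collect pieces in a list and join once at the end.
    parts = []
    for piece in pieces:
        parts.append(piece)
    return "".join(parts)

print("str +=   :", timeit.timeit(lambda: accumulate_str(chunks), number=20))
print("list+join:", timeit.timeit(lambda: accumulate_list(chunks), number=20))
```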
I think the smolagents stop sequences (for example `smolagents/src/smolagents/agents.py` line 849 at d74837b) can span multiple chunks, so checking only the latest chunk could miss them.
To your point about strings being immutable - I assumed https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py#L594 was okay since Apple is doing it. I didn't check it for a better reason than that.
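To make the chunk-boundary concern concrete (a constructed stream, not from the thread), here is a case where the stop sequence is split across two chunks, so the chunk-only check never fires while the accumulated-text check does:

```python
stop_sequences = ["<end_code>"]
# The stop sequence arrives split across the last two chunks.
chunks = ["final answer ", "<end_", "code>"]

# Checks the accumulated text: detects the stop sequence and truncates it.
print(repr(implementation_1(chunks, stop_sequences)))  # 'final answer '
# Checks only the latest chunk: never sees the full stop sequence.
print(repr(implementation_2(chunks, stop_sequences)))  # 'final answer <end_code>'
```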
For reference, I tried this approach of not appending to the text variable. I didn't find evidence of a speed up.
Gonna move this work over to https://github.com/g-eoj/mac-smolagents. Still happy to contribute but I'm skeptical this PR makes it in.
Which is a pity. I use it locally quite often now :/
Just so there is no confusion, you can still use it alongside `smolagents`.
```
Parameters:
    model_id (str):
        The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
    tool_name_key (str):
```
We don't need this `tool_name_key` in `TransformersModel`; what is the reason for needing it in `MLXModel`?
Actually, upon inspection it seems like a good idea; let's keep it, and we might set the same in `TransformersModel` later on.
I found the params to be required unless I was using a logits processor and regex to force the key names in the output, which might be a better solution overall, but I have no strong opinions yet.
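For context on how such keys are typically used (a hedged sketch, not the PR's actual code; the helper name and key defaults are hypothetical): models trained on different chat templates emit different JSON key names for tool calls, and the configurable keys let the parser pick out the tool name and arguments.

```python
import json

def parse_tool_call(output: str, tool_name_key: str = "name",
                    tool_arguments_key: str = "arguments"):
    """Hypothetical helper: pull a tool call out of model output.

    The key names vary with the model's training data, hence the
    configurable tool_name_key / tool_arguments_key parameters.
    """
    call = json.loads(output)
    return call[tool_name_key], call.get(tool_arguments_key)

name, args = parse_tool_call('{"name": "final_answer", "arguments": {"answer": "42"}}')
print(name, args)  # final_answer {'answer': '42'}
```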
@g-eoj sorry for the late review, don't hesitate to ping again when this happens!
And thank you for the contribution, great work!
The goal of this PR is to enable users to run `smolagents` with models loaded onto Apple silicon with mlx-lm. The mlx-community has made many models available for experimentation. Personally I find running locally to be a convenient way to learn and experiment with the `smolagents` library, so I made this PR as a possible new feature.

Example usage:
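The example block itself did not survive in this copy of the description; based on the corrected snippet earlier in the thread (the import path is an assumption, matching how the other model classes are exported), usage looks roughly like:

```python
from smolagents import MLXModel  # import path assumed from the PR

mlx_model = MLXModel(
    "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
    max_tokens=10000,
)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Explain quantum mechanics in simple terms."}
        ],
    }
]
print(mlx_model(messages))
```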
Some questions:
- Tests won't work in CI/CD due to hardware requirements; what is the preferred way to handle that? (One common pattern is sketched below.)
- Is anything needed for docs, and if so, where should it go?
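On the CI question, one common pattern (a sketch, not a decision from this thread; the marker and test names are hypothetical) is to gate hardware-dependent tests with a `pytest` skip condition so they run only on Apple silicon:

```python
import platform

import pytest

# Hypothetical guard: only run MLX tests on Apple silicon Macs,
# so CI runners on other hardware skip them cleanly.
requires_apple_silicon = pytest.mark.skipif(
    platform.system() != "Darwin" or platform.machine() != "arm64",
    reason="mlx-lm requires Apple silicon",
)

@requires_apple_silicon
def test_mlx_model_generates_text():
    ...
```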