Adding grpo training #1233
base: main
Conversation
Absolute HERO! I've been trying to figure this out myself for the past week but made pretty much no progress whatsoever, other than a script that fills up all the RAM on my Mac 🤣 Is there any way to run this yet? I assume not, since at the moment it's still marked as a draft and there isn't a lora_config.yaml like in the DPO example yet (not sure if it's needed)?
No, not yet. I still have to implement the dataset wrapper and some other stuff; I'll tell you when it's done.
We possibly need to use expanded_prompts and expanded_answers in both the reward computation and the loss.
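A rough illustration of what that expansion refers to (a sketch; the variable names follow the comment, everything else is assumed): in GRPO each prompt is repeated group_size times, so completions, rewards, and the loss all line up one-to-one.

group_size = 4
prompts = ["p1", "p2"]
answers = ["a1", "a2"]

# Repeat each prompt/answer once per group member.
expanded_prompts = [p for p in prompts for _ in range(group_size)]
expanded_answers = [a for a in answers for _ in range(group_size)]

# Both the reward functions and the loss would then consume the
# expanded lists rather than the originals.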
python -m mlx_lm.lora \
--model Qwen/Qwen2.5-0.5B \
--train \
--data /Users/gokdenizgulmez/Desktop/test_grpo \
--iters 5 \
--batch-size 1 \
--num-layers 4 \
--val-batches 1 \
--steps-per-report 1 \
--adapter-path /Users/gokdenizgulmez/Desktop/test-grpo-full \
--max-seq-length 128 \
--grad-checkpoint \
--training-mode grpo \
--fine-tune-type lora \
--beta 0.1 \
--steps-per-eval 500 \
--group-size 2

Output:
But after that, my 32 GB of RAM gets fully used. I tried adding some memory optimisations, but the usage is still too high.
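For anyone hitting the same wall, two generic MLX measures that sometimes help between training steps (a sketch, not the PR's actual code; after_step is a hypothetical helper):

import mlx.core as mx

def after_step(loss):
    # Force evaluation so the lazily-built graph can be freed promptly.
    mx.eval(loss)
    # Release cached Metal buffers back to the system.
    mx.metal.clear_cache()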
Still uses too much memory.
So I tried TRL as well, and it used the same amount of RAM, so the problem isn't on my side.
🚀 Would you be able to share the datasets you used for the training? Will give it a go on my machine as soon as I can 🙌
Will do that tomorrow 🤝
I created a quick one just for testing the code.
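For reference, a minimal GRPO-style test set is just a JSONL file of prompt–answer records; the field names here are an assumption, so check the PR's dataset wrapper for the exact schema:

{"prompt": "What is 150 * 8?", "answer": "1200"}
{"prompt": "What is 2 + 2?", "answer": "4"}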
python -m mlx_lm.lora \
--model Qwen/Qwen2.5-0.5B \
--train \
--data /Users/gokdenizgulmez/Desktop/test_grpo \
--iters 5 \
--batch-size 1 \
--num-layers 8 \
--val-batches 1 \
--steps-per-report 1 \
--adapter-path /Users/gokdenizgulmez/Desktop/test-grpo-full \
--max-seq-length 255 \
--grad-checkpoint \
--training-mode grpo \
--fine-tune-type lora \
--beta 0.1 \
--steps-per-eval 500 \
--group-size 2 \
--max-completion-length 6

Output:
First successful training run (I started it yesterday evening, with Qwen 1.5 Instruct):
After:
--prompt "give me a cool math proof"
--prompt "what is 150 * 8"
I'm starting a new one with the base model and the new commits |
Congratulations! By the way, what machine specs are you using?
M4 Mac mini 32 GB |
By the way, any chance you can share your dataset? I'd like to have a look.
Using the GPU could possibly save some main memory.
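For context, MLX lets you pin work to the GPU explicitly; a minimal sketch of general MLX usage (not necessarily what the PR does):

import mlx.core as mx

# Run all subsequent operations on the GPU by default.
mx.set_default_device(mx.gpu)

# Or scope just one region to the GPU:
with mx.stream(mx.gpu):
    y = mx.ones((2, 2)) * 2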
Model:
Before:
After:
Easier to read code (1/n)
Mainly adding three types: RewardFunction, GRPOExample, and GRPOBatch.
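A minimal sketch of what those three types could look like (the names come from the review; the fields and the callable signature are assumptions):

from dataclasses import dataclass
from typing import Callable, List

# Maps (prompt, completion, answer) to a scalar score.
RewardFunction = Callable[[str, str, str], float]

@dataclass
class GRPOExample:
    prompt: str
    answer: str

@dataclass
class GRPOBatch:
    examples: List[GRPOExample]
    group_size: int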
Thanks @Guo-astro, however this made the computation skyrocket: validation went from 70–80 s to 130–150 s, and training slowed by about the same, probably due to copying data multiple times through class instantiation.
True. Then I think we need to use those as sparingly as possible. Python manages all the class-instance memory, so it can be really slow 😅
I think I'll use a hybrid approach with your suggestions, because they make the code more stable, maintainable, and easier to debug and test. Thanks for your help!
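One way to keep the readability win while cutting the instantiation cost (a sketch, assuming Python 3.10+): slotted dataclasses drop the per-instance __dict__, which makes them cheaper to create and store, and the batch objects can be built once per batch rather than once per step.

from dataclasses import dataclass

@dataclass(slots=True)  # no per-instance __dict__; cheaper to create
class GRPOExample:
    prompt: str
    answer: str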
Cannot load the existing dataset on HF; the following error occurred when using Goastro/mlx-grpo-dataset for testing.

Input:

Output:
Should be fixed now. I also suggest you use the …
@mark-lord should be able to run it now!! If you want to use a base model you can use the …