Adding grpo training #1233

Open. Wants to merge 54 commits into base: main.

Commits (54)
5e0ae83
initial commit, gn
Goekdeniz-Guelmez Jan 28, 2025
b1e573d
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Jan 29, 2025
93370ff
updates ans fixing the KL div lines
Goekdeniz-Guelmez Jan 30, 2025
6c58aa9
updates
Goekdeniz-Guelmez Jan 31, 2025
80bcf68
grpo_trainer shoudl be done
Goekdeniz-Guelmez Jan 31, 2025
a57d553
update
Goekdeniz-Guelmez Jan 31, 2025
243c962
update lora.py
Goekdeniz-Guelmez Jan 31, 2025
d034ca3
adding function for R1
Goekdeniz-Guelmez Feb 3, 2025
734d6f4
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 3, 2025
a3ed632
dataset wrapper done
Goekdeniz-Guelmez Feb 3, 2025
41ff536
Merge branch 'adding-GRPO-training' of https://github.com/Goekdeniz-G…
Goekdeniz-Guelmez Feb 3, 2025
23d75cd
starting fist training test run
Goekdeniz-Guelmez Feb 3, 2025
1d9e480
first working prototype, will try training out at home
Goekdeniz-Guelmez Feb 3, 2025
05d921b
optims
Goekdeniz-Guelmez Feb 3, 2025
40bca77
fixes
Goekdeniz-Guelmez Feb 3, 2025
06f9c29
print func name
Goekdeniz-Guelmez Feb 3, 2025
54e295e
fix name funcs
Goekdeniz-Guelmez Feb 3, 2025
ca32424
updates
Goekdeniz-Guelmez Feb 3, 2025
7173840
first succesfull training run
Goekdeniz-Guelmez Feb 4, 2025
bd1a42e
adding args into dataset handling
Goekdeniz-Guelmez Feb 4, 2025
7b01414
better create_dataset
Goekdeniz-Guelmez Feb 4, 2025
0a09a93
fix cache handling
Goekdeniz-Guelmez Feb 5, 2025
2a8e6f6
udpate
Goekdeniz-Guelmez Feb 5, 2025
d84ad0c
fix testing
Goekdeniz-Guelmez Feb 5, 2025
a33cad8
udpates
Goekdeniz-Guelmez Feb 5, 2025
35a2d99
smoll fix
Goekdeniz-Guelmez Feb 5, 2025
0a19522
updates
Goekdeniz-Guelmez Feb 5, 2025
bcfa55d
updates
Goekdeniz-Guelmez Feb 5, 2025
94dcd0f
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 6, 2025
9ba6146
fix
Goekdeniz-Guelmez Feb 9, 2025
39e9469
freeze ref model
Goekdeniz-Guelmez Feb 9, 2025
5417990
fix
Goekdeniz-Guelmez Feb 9, 2025
a527cdb
fix: prevent gradients from flowing through the reference model's logits
Goekdeniz-Guelmez Feb 9, 2025
0071252
rebase loss calculation
Goekdeniz-Guelmez Feb 9, 2025
0dac286
Merge branch 'main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 10, 2025
d9da35f
nits
Goekdeniz-Guelmez Feb 10, 2025
f88e897
removing helper functions
Goekdeniz-Guelmez Feb 10, 2025
e5aa2c3
nits
Goekdeniz-Guelmez Feb 10, 2025
b7bc811
nits
Goekdeniz-Guelmez Feb 10, 2025
88ca747
nits
Goekdeniz-Guelmez Feb 10, 2025
e96afe9
updates
Goekdeniz-Guelmez Feb 11, 2025
e80bf95
fix
Goekdeniz-Guelmez Feb 11, 2025
35ecc17
fix
Goekdeniz-Guelmez Feb 11, 2025
978deab
small fix
Goekdeniz-Guelmez Feb 11, 2025
5aeefc8
update new iterade batches function + nits
Goekdeniz-Guelmez Feb 12, 2025
c42e858
Merge branch 'adding-GRPO-training' of https://github.com/Goekdeniz-G…
Goekdeniz-Guelmez Feb 12, 2025
e33d9d5
updates
Goekdeniz-Guelmez Feb 12, 2025
3823154
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 12, 2025
a7273f6
small fix
Goekdeniz-Guelmez Feb 12, 2025
8179b99
quick prompting fix
Goekdeniz-Guelmez Feb 12, 2025
65a49dd
nits
Goekdeniz-Guelmez Feb 13, 2025
baeb9f1
reduncancy fix + nits
Goekdeniz-Guelmez Feb 14, 2025
5ec4790
removing comments + adding temperature + reward weighting
Goekdeniz-Guelmez Feb 15, 2025
6a6bd53
removing print and switching some variables in the math
Goekdeniz-Guelmez Feb 15, 2025
212 changes: 175 additions & 37 deletions llms/mlx_lm/lora.py
@@ -1,20 +1,21 @@
# Copyright © 2024 Apple Inc.

from pathlib import Path
import argparse
import types
import math
import os
import re
import types
from pathlib import Path

import mlx.nn as nn
import mlx.optimizers as optim
import mlx.nn as nn
import numpy as np
import yaml

from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
@@ -42,6 +43,7 @@
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"training_mode": "normal",
"fine_tune_type": "lora",
"data": "data/",
"seed": 0,
@@ -62,6 +64,17 @@
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},

# GRPO args
"reference_model_path": None,
"group_size": 4,
"beta": 0.1,
"epsilon": 1e-4,
"max_completion_length": 512,
"use_chat_template": False,
"use_prompt": False,
"temperature": 1.0,
"reward_weights": None,
}


@@ -102,6 +115,12 @@ def build_parser():
default=False,
)

parser.add_argument(
"--training-mode",
type=str,
choices=["normal", "grpo"],
help="Training mode: normal or GRPO",
)
parser.add_argument(
"--num-layers",
type=int,
@@ -169,8 +188,93 @@ def build_parser():
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")

# GRPO args
parser.add_argument(
"--group-size",
type=int,
help="Number of completions to sample per prompt (the GRPO group size).",
default=4,
)
parser.add_argument(
"--max-completion-length",
type=int,
help="Maximum number of tokens to generate for each completion.",
default=512,
)
parser.add_argument(
"--beta",
type=float,
help="KL penalty coefficient.",
default=0.1,
)
parser.add_argument(
"--epsilon",
type=float,
help="The Epsilon for numerical stability.",
default=1e-4,
)
parser.add_argument(
"--use-chat-template",
action="store_true",
help="If the model is a Chat model, use the Chat template.",
default=None,
)
parser.add_argument(
"--use-prompt",
action="store_true",
help="Whether to use the prompt from the R1 paper.",
default=None,
)
parser.add_argument(
"--temperature",
type=float,
help="Temperature for sampling. The higher the temperature, the more random the completions.",
default=1.0,
)
parser.add_argument(
"--reward-weights",
type=str,
help="Weights for each reward function. Must match the number of reward functions and be in this format [0.1, 0.2, 0.3, 0.4, 0.5]. If not given, all rewards are weighted equally with weight `1.0`.",
default=None,
)
return parser
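The diff only exposes the new GRPO hyperparameters, not the trainer itself. As a point of reference, here is a hedged sketch of the group-relative advantage normalization that `--group-size` and `--epsilon` commonly refer to in GRPO; whether `tuner/grpo_trainer.py` applies epsilon exactly this way is not visible in this diff, and the helper name below is illustrative.

```python
# Hedged sketch, not the PR's code: group-relative advantage normalization
# as in standard GRPO. The real implementation lives in tuner/grpo_trainer.py.
import mlx.core as mx

def group_advantages(rewards: mx.array, epsilon: float = 1e-4) -> mx.array:
    # rewards: shape (group_size,), one scalar reward per sampled completion
    return (rewards - mx.mean(rewards)) / (mx.std(rewards) + epsilon)

print(group_advantages(mx.array([1.0, 0.0, 0.5, 0.5])))  # roughly zero-mean, unit-scale
```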

def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
training_args = GRPOTrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
max_completion_length=args.max_completion_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
reference_model_path=args.reference_model_path,
temperature=args.temperature,
reward_weights=[float(x) for x in args.reward_weights.strip('[]').split(',')] if args.reward_weights else None
)

if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model, _ = load(args.model)

train_grpo(
model=model,
ref_model=reference_model.freeze(),
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
)
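For clarity, the string-to-list conversion applied to `--reward-weights` in `train_model_grpo` can be read in isolation. This is a minimal sketch that mirrors the list comprehension above; the helper name is illustrative and not part of the PR.

```python
# Sketch of the --reward-weights parsing used in train_model_grpo above.
def parse_reward_weights(raw: str | None) -> list[float] | None:
    if not raw:
        return None  # None means every reward function gets weight 1.0
    return [float(x) for x in raw.strip("[]").split(",")]

assert parse_reward_weights("[0.1, 0.2, 0.7]") == [0.1, 0.2, 0.7]
assert parse_reward_weights(None) is None
```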

def train_model(
args,
@@ -208,19 +312,6 @@ def train_model(
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")

# init training args
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
)

model.train()
opt = optim.Adam(
learning_rate=(
@@ -229,32 +320,79 @@
)

# Train model
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)
if args.training_mode == "grpo":
train_model_grpo(
model,
tokenizer,
args,
opt,
train_set,
valid_set,
adapter_file,
training_callback
)
else:
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint
)

train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)


def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval()

test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)
if args.training_mode == "grpo":
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model, _ = load(args.model)

test_loss, _, test_rewards = evaluate_grpo(
model=model,
ref_model=reference_model.freeze(),
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon
)

test_ppl = math.exp(test_loss)

print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
else:
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)

test_ppl = math.exp(test_loss)
test_ppl = math.exp(test_loss)

print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def run(args, training_callback: TrainingCallback = None):
@@ -305,4 +443,4 @@ def main():


if __name__ == "__main__":
main()
main()
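To tie the pieces together, a hedged usage sketch follows. The flag spellings come from `build_parser` in this diff, `mlx_lm.lora` is the module being modified here and exposes `main()`, and the model and data paths are placeholders rather than values from the PR.

```python
# Hedged usage sketch (paths are placeholders, not from this PR).
from mlx_lm import lora

# Equivalent CLI, with flag names taken from build_parser in this diff:
#   python -m mlx_lm.lora --model <path-or-hf-repo> --train \
#       --training-mode grpo --data <data-dir> \
#       --group-size 4 --beta 0.1 --epsilon 1e-4 \
#       --max-completion-length 512 --temperature 1.0

lora.main()  # parses sys.argv, so pass the flags above when invoking this script
```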