I noticed, when trying to train Llama3.1-8B with the RewardTrainer on my own problem, that no matter what I tried I couldn't get it to converge. Simply swapping Llama3.1 for Qwen2.5 (I tried sizes from 0.5B to 7B), the model converged without issue; I kept all hyperparameters the same and used the same data (which unfortunately I cannot share).
To reproduce, I ran the official reward_modeling.py example script with Qwen2-0.5B vs. Llama3.1-1B (both instruct). Both were trained on one node of 8xA6000s with the same accelerate config.
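For reference, the runs looked roughly like this (illustrative hyperparameters and dataset, not the exact values I used; only --model_name_or_path changed between the two runs):

```bash
accelerate launch examples/scripts/reward_modeling.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --output_dir Qwen2-0.5B-Reward \
    --per_device_train_batch_size 8 \
    --num_train_epochs 1 \
    --learning_rate 1.0e-5 \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 50 \
    --max_length 2048
```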
I did have to add the line tokenizer.pad_token = tokenizer.eos_token in reward_modeling.py right after the tokenizer and model initialization, similar to how it's done in sft.py, because Llama does not have a pad token. (Side note: maybe reward_modeling.py should do this automatically when there is no pad token? Happy to PR if so.)
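Roughly, that part of the script with the workaround looks like this (a sketch, not the literal diff; the checkpoint name is just an example):

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "meta-llama/Llama-3.1-8B-Instruct"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)

# Llama tokenizers ship without a pad token, so fall back to EOS, like sft.py does.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id
```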
Sure enough, I see the same issues with convergence.
I don't know if this is a known issue, but I wanted to flag it in case it is (and someone knows the fix), or in case it isn't and I'm doing something dumb that someone would be kind enough to point out!
Huh, it looks like it comes down to what you use as the pad token itself. If I use one of Llama's unused special tokens, <|reserved_special_token_0|> (token id 128002), as the pad token, it works! If I use the EOS token as padding, it doesn't...
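If anyone wants to try the same thing, this is roughly the change (a sketch; the checkpoint name is illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # illustrative

# Pad with an unused reserved special token instead of EOS.
tokenizer.pad_token = "<|reserved_special_token_0|>"
print(tokenizer.pad_token_id)  # 128002 in the Llama 3 tokenizer
# Remember to also set model.config.pad_token_id = tokenizer.pad_token_id on the model.
```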
Another callout is that, in either case, we somehow end up with two BOS tokens in the chosen/rejected pairs.
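My assumption is that the double BOS comes from applying the chat template (which already inserts <|begin_of_text|>) and then tokenizing the result with add_special_tokens left at its default of True. A standalone check of that tokenizer behaviour (the messages are made up):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # illustrative

messages = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
]

# The chat template already starts the string with <|begin_of_text|> ...
text = tokenizer.apply_chat_template(messages, tokenize=False)

# ... and plain tokenization (add_special_tokens=True by default) prepends BOS again.
ids = tokenizer(text)["input_ids"]
print(tokenizer.convert_ids_to_tokens(ids[:3]))
# expected: ['<|begin_of_text|>', '<|begin_of_text|>', ...]
```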