🩳 max_seq_length to max_length (#2895)
* `max_seq_length` to `max_length`

* remove in 0.20
qgallouedec authored Feb 18, 2025
1 parent 6aaf379 commit be1e340
Showing 11 changed files with 70 additions and 62 deletions.
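For quick reference, the user-facing change is a single renamed argument in the trainer config. The sketch below is illustrative only, assuming the `trl` `SFTConfig` API shown in the documentation diffs that follow; the `output_dir` and the value 512 are placeholders.

```python
from trl import SFTConfig

# Before this commit, the truncation limit was configured as `max_seq_length`:
# training_args = SFTConfig(output_dir="/tmp", max_seq_length=512)

# After this commit, the same setting is named `max_length`:
training_args = SFTConfig(output_dir="/tmp", max_length=512)
```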
2 changes: 1 addition & 1 deletion commands/run_sft.sh
@@ -42,7 +42,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
- --max_seq_length $SEQ_LEN \
+ --max_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""

6 changes: 3 additions & 3 deletions docs/source/reducing_memory_usage.md
@@ -44,7 +44,7 @@ training_args = DPOConfig(..., max_completion_length=...)
</hfoption>
<hfoption id="SFT">

- SFT truncation is applied to the input sequence via the `max_seq_length` parameter.
+ SFT truncation is applied to the input sequence via the `max_length` parameter.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_input_ids.png" alt="Truncation input ids" width="600"/>
@@ -55,7 +55,7 @@ To set the truncation parameter, use the following code snippet:
```python
from trl import SFTConfig

- training_args = SFTConfig(..., max_seq_length=...)
+ training_args = SFTConfig(..., max_length=...)
```

</hfoption>
@@ -85,7 +85,7 @@ Packing eliminates padding, preserves all sequence information, and allows for f
```python
from trl import SFTConfig

- training_args = SFTConfig(..., packing=True, max_seq_length=512)
+ training_args = SFTConfig(..., packing=True, max_length=512)
```

<Tip warning={true}>
12 changes: 6 additions & 6 deletions docs/source/sft_trainer.md
@@ -19,7 +19,7 @@ from trl import SFTConfig, SFTTrainer
dataset = load_dataset("stanfordnlp/imdb", split="train")

training_args = SFTConfig(
- max_seq_length=512,
+ max_length=512,
output_dir="/tmp",
)
trainer = SFTTrainer(
@@ -29,7 +29,7 @@ trainer = SFTTrainer(
)
trainer.train()
```
- Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
+ Make sure to pass the correct value for `max_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.

You can also construct a model outside of the trainer and pass it as follows:

@@ -550,12 +550,12 @@ import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel

- max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number
+ max_length = 2048 # Supports automatic RoPE Scaling, so choose any number

# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/mistral-7b",
- max_seq_length=max_seq_length,
+ max_seq_length=max_length,
dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
@@ -581,7 +581,7 @@ model = FastLanguageModel.get_peft_model(
random_state=3407,
)

- training_args = SFTConfig(output_dir="./output", max_seq_length=max_seq_length)
+ training_args = SFTConfig(output_dir="./output", max_length=max_length)

trainer = SFTTrainer(
model=model,
@@ -624,7 +624,7 @@ To learn more about Liger-Kernel, visit their [official repository](https://gith

Pay attention to the following best practices when training a model with that trainer:

- - [`SFTTrainer`] always truncates by default the sequences to the `max_seq_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
+ - [`SFTTrainer`] always truncates by default the sequences to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
(additional changed file; file name not rendered on this page)
@@ -185,7 +185,7 @@ def create_datasets(tokenizer, args, seed=None):
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
- max_seq_length=None,
+ max_length=None,
formatting_func=prepare_sample_text,
processing_class=tokenizer,
args=training_args,
22 changes: 11 additions & 11 deletions tests/slow/test_sft_slow.py
@@ -46,7 +46,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
def setUp(self):
self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]")
- self.max_seq_length = 128
+ self.max_length = 128
self.peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
@@ -74,7 +74,7 @@ def test_sft_trainer_str(self, model_name, packing):
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
)

trainer = SFTTrainer(
@@ -100,7 +100,7 @@ def test_sft_trainer_transformers(self, model_name, packing):
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
)

model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -135,7 +135,7 @@ def test_sft_trainer_peft(self, model_name, packing):
max_steps=10,
fp16=True,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
)

model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -172,7 +172,7 @@ def test_sft_trainer_transformers_mp(self, model_name, packing):
max_steps=10,
fp16=True, # this is sufficient to enable amp
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
)

model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -205,7 +205,7 @@ def test_sft_trainer_transformers_mp_gc(self, model_name, packing, gradient_chec
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -242,7 +242,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -286,7 +286,7 @@ def test_sft_trainer_transformers_mp_gc_device_map(
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -324,7 +324,7 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -364,7 +364,7 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):

training_args = SFTConfig(
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
@@ -411,7 +411,7 @@ def test_sft_trainer_with_liger(self, model_name, packing):
per_device_train_batch_size=2,
max_steps=2,
packing=packing,
- max_seq_length=self.max_seq_length,
+ max_length=self.max_length,
use_liger=True,
)

30 changes: 15 additions & 15 deletions tests/test_sft_trainer.py
@@ -326,7 +326,7 @@ def test_sft_trainer_uncorrect_data(self):
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=32, # make sure there is at least 1 packed sequence
+ max_length=32, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@@ -353,15 +353,15 @@ def test_sft_trainer_uncorrect_data(self):
train_dataset=self.conversational_lm_dataset["train"],
)

- # Same, but with packing with `max_seq_length`
+ # Same, but with packing with `max_length`
training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
max_steps=2,
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=16, # make sure there is at least 1 packed sequence
+ max_length=16, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@@ -396,7 +396,7 @@ def test_sft_trainer_uncorrect_data(self):
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=32, # make sure there is at least 1 packed sequence
+ max_length=32, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@@ -461,7 +461,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):
save_steps=1,
num_train_epochs=2,
per_device_train_batch_size=2,
- max_seq_length=16,
+ max_length=16,
packing=True,
report_to="none",
)
@@ -485,7 +485,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):
save_steps=1,
num_train_epochs=2,
per_device_train_batch_size=2,
- max_seq_length=16,
+ max_length=16,
report_to="none",
)
trainer = SFTTrainer(
@@ -534,7 +534,7 @@ def test_sft_trainer_with_model(self):
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=16,
+ max_length=16,
packing=True,
report_to="none",
)
@@ -558,7 +558,7 @@ def test_sft_trainer_with_model(self):
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=16,
+ max_length=16,
packing=True,
report_to="none",
)
@@ -583,7 +583,7 @@ def test_sft_trainer_with_model(self):
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=16,
+ max_length=16,
report_to="none",
)
trainer = SFTTrainer(
@@ -606,7 +606,7 @@ def test_sft_trainer_with_model(self):
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
- max_seq_length=16,
+ max_length=16,
report_to="none",
)
trainer = SFTTrainer(
@@ -755,7 +755,7 @@ def test_sft_trainer_infinite_with_model(self):
save_steps=1,
per_device_train_batch_size=2,
packing=True,
- max_seq_length=500,
+ max_length=500,
report_to="none",
)
trainer = SFTTrainer(
@@ -782,7 +782,7 @@ def test_sft_trainer_infinite_with_model_epochs(self):
per_device_train_batch_size=2,
save_strategy="epoch",
packing=True,
- max_seq_length=500,
+ max_length=500,
report_to="none",
)
trainer = SFTTrainer(
@@ -1088,7 +1088,7 @@ def test_sft_trainer_only_train_packing(self):
per_device_train_batch_size=2,
gradient_checkpointing=True,
packing=True,
- max_seq_length=16, # make sure there is at least 1 packed sequence
+ max_length=16, # make sure there is at least 1 packed sequence
eval_packing=False,
report_to="none",
)
@@ -1114,7 +1114,7 @@ def test_sft_trainer_eval_packing(self):
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
- max_seq_length=16, # make sure there is at least 1 packed sequence
+ max_length=16, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@@ -1139,7 +1139,7 @@ def test_sft_trainer_no_packing(self):
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
- max_seq_length=16, # make sure there is at least 1 packed sequence
+ max_length=16, # make sure there is at least 1 packed sequence
packing=False,
report_to="none",
)
4 changes: 2 additions & 2 deletions tests/test_trainers_args.py
@@ -368,7 +368,7 @@ def test_sft(self):
tmp_dir,
dataset_text_field="dummy_text_field",
packing=True,
- max_seq_length=256,
+ max_length=256,
dataset_num_proc=4,
dataset_batch_size=512,
neftune_noise_alpha=0.1,
@@ -379,7 +379,7 @@ def test_sft(self):
trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset)
self.assertEqual(trainer.args.dataset_text_field, "dummy_text_field")
self.assertEqual(trainer.args.packing, True)
- self.assertEqual(trainer.args.max_seq_length, 256)
+ self.assertEqual(trainer.args.max_length, 256)
self.assertEqual(trainer.args.dataset_num_proc, 4)
self.assertEqual(trainer.args.dataset_batch_size, 512)
self.assertEqual(trainer.args.neftune_noise_alpha, 0.1)
2 changes: 1 addition & 1 deletion trl/trainer/gkd_trainer.py
@@ -87,7 +87,7 @@ def __init__(
):
# add remove_unused_columns=False to the dataclass args
args.remove_unused_columns = False
- data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
+ data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)

super().__init__(
model,
(The remaining changed files are not rendered on this page.)
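The commit message's "remove in 0.20" note implies the old `max_seq_length` name is kept as a deprecated alias in the configuration code that is not rendered above. The actual shim is not shown in this diff; the following is only a hypothetical sketch of how such a rename-with-deprecation is commonly wired up. The class name, default value, and warning text are illustrative, not TRL's implementation.

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleConfig:
    """Hypothetical config illustrating a rename-with-deprecation shim."""

    max_length: Optional[int] = 1024
    # Deprecated alias kept temporarily for backward compatibility.
    max_seq_length: Optional[int] = None

    def __post_init__(self):
        if self.max_seq_length is not None:
            warnings.warn(
                "`max_seq_length` is deprecated and will be removed in a future "
                "release; use `max_length` instead.",
                DeprecationWarning,
            )
            # Forward the old setting to the new field.
            self.max_length = self.max_seq_length
```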
