diff --git a/commands/run_sft.sh b/commands/run_sft.sh
index bdea77fcb6..b7beaaf7fd 100644
--- a/commands/run_sft.sh
+++ b/commands/run_sft.sh
@@ -42,7 +42,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
     --output_dir $OUTPUT_DIR \
     --max_steps $MAX_STEPS \
     --per_device_train_batch_size $BATCH_SIZE \
-    --max_seq_length $SEQ_LEN \
+    --max_length $SEQ_LEN \
     $EXTRA_TRAINING_ARGS
 """
diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md
index 6c05490616..cc335156e6 100644
--- a/docs/source/reducing_memory_usage.md
+++ b/docs/source/reducing_memory_usage.md
@@ -44,7 +44,7 @@ training_args = DPOConfig(..., max_completion_length=...)
-SFT truncation is applied to the input sequence via the `max_seq_length` parameter.
+SFT truncation is applied to the input sequence via the `max_length` parameter.
 Truncation input ids
@@ -55,7 +55,7 @@ To set the truncation parameter, use the following code snippet:
 ```python
 from trl import SFTConfig
-training_args = SFTConfig(..., max_seq_length=...)
+training_args = SFTConfig(..., max_length=...)
 ```
@@ -85,7 +85,7 @@ Packing eliminates padding, preserves all sequence information, and allows for f
 ```python
 from trl import SFTConfig
-training_args = SFTConfig(..., packing=True, max_seq_length=512)
+training_args = SFTConfig(..., packing=True, max_length=512)
 ```
diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md
index ab3e9e1cc5..5c30b744fa 100644
--- a/docs/source/sft_trainer.md
+++ b/docs/source/sft_trainer.md
@@ -19,7 +19,7 @@ from trl import SFTConfig, SFTTrainer
 dataset = load_dataset("stanfordnlp/imdb", split="train")
 training_args = SFTConfig(
-    max_seq_length=512,
+    max_length=512,
     output_dir="/tmp",
 )
 trainer = SFTTrainer(
@@ -29,7 +29,7 @@ trainer = SFTTrainer(
 )
 trainer.train()
 ```
-Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
+Make sure to pass the correct value for `max_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
 You can also construct a model outside of the trainer and pass it as follows:
@@ -550,12 +550,12 @@ import torch
 from trl import SFTConfig, SFTTrainer
 from unsloth import FastLanguageModel
-max_seq_length = 2048  # Supports automatic RoPE Scaling, so choose any number
+max_length = 2048  # Supports automatic RoPE Scaling, so choose any number
 # Load model
 model, tokenizer = FastLanguageModel.from_pretrained(
     model_name="unsloth/mistral-7b",
-    max_seq_length=max_seq_length,
+    max_seq_length=max_length,
     dtype=None,  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
     load_in_4bit=True,  # Use 4bit quantization to reduce memory usage. Can be False
     # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
@@ -581,7 +581,7 @@ model = FastLanguageModel.get_peft_model(
     random_state=3407,
 )
-training_args = SFTConfig(output_dir="./output", max_seq_length=max_seq_length)
+training_args = SFTConfig(output_dir="./output", max_length=max_length)
 trainer = SFTTrainer(
     model=model,
@@ -624,7 +624,7 @@ To learn more about Liger-Kernel, visit their [official repository](https://gith
 Pay attention to the following best practices when training a model with that trainer:
-- [`SFTTrainer`] always truncates by default the sequences to the `max_seq_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
+- [`SFTTrainer`] always truncates by default the sequences to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
 - For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
 - For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
 - If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
diff --git a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py
index 3ae1e82c2a..1f4611a3e8 100644
--- a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py
+++ b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py
@@ -185,7 +185,7 @@ def create_datasets(tokenizer, args, seed=None):
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
     peft_config=peft_config,
-    max_seq_length=None,
+    max_length=None,
     formatting_func=prepare_sample_text,
     processing_class=tokenizer,
     args=training_args,
diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py
index 74811d092c..8a772a48aa 100644
--- a/tests/slow/test_sft_slow.py
+++ b/tests/slow/test_sft_slow.py
@@ -46,7 +46,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
     def setUp(self):
         self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
         self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]")
-        self.max_seq_length = 128
+        self.max_length = 128
         self.peft_config = LoraConfig(
             lora_alpha=16,
             lora_dropout=0.1,
@@ -74,7 +74,7 @@ def test_sft_trainer_str(self, model_name, packing):
                 per_device_train_batch_size=2,
                 max_steps=10,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
             )
             trainer = SFTTrainer(
@@ -100,7 +100,7 @@ def test_sft_trainer_transformers(self, model_name, packing):
                 per_device_train_batch_size=2,
                 max_steps=10,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
             )
             model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -135,7 +135,7 @@ def test_sft_trainer_peft(self, model_name, packing):
                 max_steps=10,
                 fp16=True,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
             )
             model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -172,7 +172,7 @@ def test_sft_trainer_transformers_mp(self, model_name, packing):
                 max_steps=10,
                 fp16=True,  # this is sufficient to enable amp
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
             )
             model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -205,7 +205,7 @@ def test_sft_trainer_transformers_mp_gc(self, model_name, packing, gradient_chec
                 per_device_train_batch_size=2,
                 max_steps=10,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
                 fp16=True,  # this is sufficient to enable amp
                 gradient_checkpointing=True,
                 gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -242,7 +242,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient
                 per_device_train_batch_size=2,
                 max_steps=10,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
                 fp16=True,  # this is sufficient to enable amp
                 gradient_checkpointing=True,
                 gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -286,7 +286,7 @@ def test_sft_trainer_transformers_mp_gc_device_map(
                 per_device_train_batch_size=2,
                 max_steps=10,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
                 fp16=True,  # this is sufficient to enable amp
                 gradient_checkpointing=True,
                 gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -324,7 +324,7 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
                 per_device_train_batch_size=2,
                 max_steps=10,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
                 fp16=True,  # this is sufficient to enable amp
                 gradient_checkpointing=True,
                 gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -364,7 +364,7 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
             training_args = SFTConfig(
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
                 output_dir=tmp_dir,
                 logging_strategy="no",
                 report_to="none",
@@ -411,7 +411,7 @@ def test_sft_trainer_with_liger(self, model_name, packing):
                 per_device_train_batch_size=2,
                 max_steps=2,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
                 use_liger=True,
             )
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 14d235585a..1a26378f3f 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -326,7 +326,7 @@ def test_sft_trainer_uncorrect_data(self):
                 eval_steps=1,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=32,  # make sure there is at least 1 packed sequence
+                max_length=32,  # make sure there is at least 1 packed sequence
                 packing=True,
                 report_to="none",
             )
@@ -353,7 +353,7 @@ def test_sft_trainer_uncorrect_data(self):
                 train_dataset=self.conversational_lm_dataset["train"],
             )
-            # Same, but with packing with `max_seq_length`
+            # Same, but with packing with `max_length`
             training_args = SFTConfig(
                 output_dir=tmp_dir,
                 dataloader_drop_last=True,
@@ -361,7 +361,7 @@ def test_sft_trainer_uncorrect_data(self):
                 eval_steps=1,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=16,  # make sure there is at least 1 packed sequence
+                max_length=16,  # make sure there is at least 1 packed sequence
                 packing=True,
                 report_to="none",
             )
@@ -396,7 +396,7 @@ def test_sft_trainer_uncorrect_data(self):
                 eval_steps=1,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=32,  # make sure there is at least 1 packed sequence
+                max_length=32,  # make sure there is at least 1 packed sequence
                 packing=True,
                 report_to="none",
             )
@@ -461,7 +461,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):
                 save_steps=1,
                 num_train_epochs=2,
                 per_device_train_batch_size=2,
-                max_seq_length=16,
+                max_length=16,
                 packing=True,
                 report_to="none",
             )
@@ -485,7 +485,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):
                 save_steps=1,
                 num_train_epochs=2,
                 per_device_train_batch_size=2,
-                max_seq_length=16,
+                max_length=16,
                 report_to="none",
             )
             trainer = SFTTrainer(
@@ -534,7 +534,7 @@ def test_sft_trainer_with_model(self):
                 max_steps=2,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=16,
+                max_length=16,
                 packing=True,
                 report_to="none",
             )
@@ -558,7 +558,7 @@ def test_sft_trainer_with_model(self):
                 max_steps=2,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=16,
+                max_length=16,
                 packing=True,
                 report_to="none",
             )
@@ -583,7 +583,7 @@ def test_sft_trainer_with_model(self):
                 max_steps=2,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=16,
+                max_length=16,
                 report_to="none",
             )
             trainer = SFTTrainer(
@@ -606,7 +606,7 @@ def test_sft_trainer_with_model(self):
                 max_steps=2,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=16,
+                max_length=16,
                 report_to="none",
             )
             trainer = SFTTrainer(
@@ -755,7 +755,7 @@ def test_sft_trainer_infinite_with_model(self):
                 save_steps=1,
                 per_device_train_batch_size=2,
                 packing=True,
-                max_seq_length=500,
+                max_length=500,
                 report_to="none",
             )
             trainer = SFTTrainer(
@@ -782,7 +782,7 @@ def test_sft_trainer_infinite_with_model_epochs(self):
                 per_device_train_batch_size=2,
                 save_strategy="epoch",
                 packing=True,
-                max_seq_length=500,
+                max_length=500,
                 report_to="none",
             )
             trainer = SFTTrainer(
@@ -1088,7 +1088,7 @@ def test_sft_trainer_only_train_packing(self):
                 per_device_train_batch_size=2,
                 gradient_checkpointing=True,
                 packing=True,
-                max_seq_length=16,  # make sure there is at least 1 packed sequence
+                max_length=16,  # make sure there is at least 1 packed sequence
                 eval_packing=False,
                 report_to="none",
             )
@@ -1114,7 +1114,7 @@ def test_sft_trainer_eval_packing(self):
                 save_steps=2,
                 per_device_train_batch_size=2,
                 gradient_checkpointing=True,
-                max_seq_length=16,  # make sure there is at least 1 packed sequence
+                max_length=16,  # make sure there is at least 1 packed sequence
                 packing=True,
                 report_to="none",
             )
@@ -1139,7 +1139,7 @@ def test_sft_trainer_no_packing(self):
                 save_steps=2,
                 per_device_train_batch_size=2,
                 gradient_checkpointing=True,
-                max_seq_length=16,  # make sure there is at least 1 packed sequence
+                max_length=16,  # make sure there is at least 1 packed sequence
                 packing=False,
                 report_to="none",
             )
diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py
index 251b1f5a96..406eba4f86 100644
--- a/tests/test_trainers_args.py
+++ b/tests/test_trainers_args.py
@@ -368,7 +368,7 @@ def test_sft(self):
                 tmp_dir,
                 dataset_text_field="dummy_text_field",
                 packing=True,
-                max_seq_length=256,
+                max_length=256,
                 dataset_num_proc=4,
                 dataset_batch_size=512,
                 neftune_noise_alpha=0.1,
@@ -379,7 +379,7 @@ def test_sft(self):
             trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset)
             self.assertEqual(trainer.args.dataset_text_field, "dummy_text_field")
             self.assertEqual(trainer.args.packing, True)
-            self.assertEqual(trainer.args.max_seq_length, 256)
+            self.assertEqual(trainer.args.max_length, 256)
             self.assertEqual(trainer.args.dataset_num_proc, 4)
             self.assertEqual(trainer.args.dataset_batch_size, 512)
             self.assertEqual(trainer.args.neftune_noise_alpha, 0.1)
diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py
index 5e76f30ed2..7cfad453f7 100644
--- a/trl/trainer/gkd_trainer.py
+++ b/trl/trainer/gkd_trainer.py
@@ -87,7 +87,7 @@ def __init__(
     ):
         # add remove_unused_columns=False to the dataclass args
         args.remove_unused_columns = False
-        data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
+        data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
         super().__init__(
             model,
diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py
index ad0e936c18..23b617dfe2 100644
--- a/trl/trainer/sft_config.py
+++ b/trl/trainer/sft_config.py
@@ -49,13 +49,11 @@ class SFTConfig(TrainingArguments):
             `skip_prepare_dataset`.
         dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
             Number of processes to use for processing the dataset.
-        max_seq_length (`int` or `None`, *optional*, defaults to `1024`):
-            Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the
-            right.
+        max_length (`int` or `None`, *optional*, defaults to `1024`):
+            Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
             If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
         packing (`bool`, *optional*, defaults to `False`):
-            Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence
-            length.
+            Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define sequence length.
         eval_packing (`bool` or `None`, *optional*, defaults to `None`):
             Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
@@ -95,19 +93,19 @@ class SFTConfig(TrainingArguments):
         default=None,
         metadata={"help": "Number of processes to use for processing the dataset."},
     )
-    max_seq_length: Optional[int] = field(
+    max_length: Optional[int] = field(
         default=1024,
         metadata={
-            "help": "Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated "
-            "from the right. If `None`, no truncation is applied. When packing is enabled, this value sets the "
+            "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from "
+            "the right. If `None`, no truncation is applied. When packing is enabled, this value sets the "
             "sequence length."
         },
     )
     packing: bool = field(
         default=False,
         metadata={
-            "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to "
-            "define sequence length."
+            "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define "
+            "sequence length."
         },
     )
     eval_packing: Optional[bool] = field(
@@ -132,13 +130,17 @@ class SFTConfig(TrainingArguments):
     num_of_sequences: int = field(
         default=None,
         metadata={
-            "help": "Deprecated. Use `max_seq_length` instead, which specifies the maximum length of the tokenized "
+            "help": "Deprecated. Use `max_length` instead, which specifies the maximum length of the tokenized "
             "sequence, unlike `num_of_sequences`, which referred to string sequences."
         },
     )
     chars_per_token: float = field(
         default=None,
-        metadata={"help": "Deprecated. If you want to customize the packing length, use `max_seq_length`."},
+        metadata={"help": "Deprecated. If you want to customize the packing length, use `max_length`."},
    )
+    max_seq_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "Deprecated. Use `max_length` instead."},
+    )
     def __post_init__(self):
@@ -153,7 +155,7 @@ def __post_init__(self):
         if self.num_of_sequences is not None:
             warnings.warn(
-                "`num_of_sequences` is deprecated and will be remove in version 0.18.0. Use `max_seq_length` instead, "
+                "`num_of_sequences` is deprecated and will be remove in version 0.18.0. Use `max_length` instead, "
                 "which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which r"
                 "eferred to string sequences.",
                 DeprecationWarning,
             )
@@ -162,6 +164,12 @@ def __post_init__(self):
         if self.chars_per_token is not None:
             warnings.warn(
                 "`chars_per_token` is deprecated and will be remove in version 0.18.0. If you want to customize the "
-                "packing length, use `max_seq_length`.",
+                "packing length, use `max_length`.",
                 DeprecationWarning,
             )
+
+        if self.max_seq_length is not None:
+            warnings.warn(
+                "`max_seq_length` is deprecated and will be removed in version 0.20.0. Use `max_length` instead.",
+                DeprecationWarning,
+            )
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index e4708eb7c7..b0104f4b53 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -434,17 +434,17 @@ def tokenize(ex):
             # Pack or truncate
             if packing:
-                if args.max_seq_length is None:
-                    raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
+                if args.max_length is None:
+                    raise ValueError("When packing is enabled, `max_length` can't be `None`.")
                 if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                     map_kwargs["desc"] = f"Packing {dataset_name} dataset"
                 dataset = dataset.select_columns("input_ids")
                 dataset = dataset.map(
-                    pack_examples, batched=True, fn_kwargs={"seq_length": args.max_seq_length}, **map_kwargs
+                    pack_examples, batched=True, fn_kwargs={"seq_length": args.max_length}, **map_kwargs
                 )
-            elif args.max_seq_length is not None:
+            elif args.max_length is not None:
                 dataset = dataset.map(
-                    lambda ex: {key: ex[key][: args.max_seq_length] for key in ["input_ids", "attention_mask"]},
+                    lambda ex: {key: ex[key][: args.max_length] for key in ["input_ids", "attention_mask"]},
                     **map_kwargs,
                 )
             # For Liger kernel, ensure only input_ids is present
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 7a20645535..853ba1f3ca 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -140,7 +140,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
                     warnings.warn(
                         f"Could not find response key `{self.response_template}` in the following instance: "
                         f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
-                        "calculation. Note, if this happens often, consider increasing the `max_seq_length`.",
+                        "calculation. Note, if this happens often, consider increasing the `max_length`.",
                         UserWarning,
                     )
                     batch["labels"][i, :] = self.ignore_index
@@ -167,7 +167,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
                     warnings.warn(
                         f"Could not find response key `{self.response_template}` in the following instance: "
                         f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
-                        "calculation. Note, if this happens often, consider increasing the `max_seq_length`.",
+                        "calculation. Note, if this happens often, consider increasing the `max_length`.",
                         UserWarning,
                     )
                     batch["labels"][i, :] = self.ignore_index
@@ -182,7 +182,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
                     warnings.warn(
                         f"Could not find instruction key `{self.instruction_template}` in the following instance: "
                         f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
-                        "calculation. Note, if this happens often, consider increasing the `max_seq_length`.",
+                        "calculation. Note, if this happens often, consider increasing the `max_length`.",
                         UserWarning,
                     )
                     batch["labels"][i, :] = self.ignore_index
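
For context, the snippet below is a minimal sketch (not part of the patch) of how the renamed argument is used after this change. It mirrors the quickstart shown in the docs/source/sft_trainer.md hunk above; the `facebook/opt-350m` checkpoint is only an illustrative choice and is not taken from this diff.

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("stanfordnlp/imdb", split="train")

# `max_length` replaces the former `max_seq_length`; the old name survives only as a
# deprecated SFTConfig field that emits a DeprecationWarning (removal planned for 0.20.0).
training_args = SFTConfig(
    output_dir="/tmp",
    max_length=512,  # truncation limit; also the packed sequence length when packing=True
)

trainer = SFTTrainer(
    model="facebook/opt-350m",  # illustrative checkpoint; any causal LM id works
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

As the `__post_init__` hunk above shows, passing the old `max_seq_length` keyword is still accepted but only triggers the deprecation warning.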