Skip to content

Commit

Permalink
[docs] fix outdated example code in trainer.md (#36066)
Browse files Browse the repository at this point in the history
fix bugs
  • Loading branch information
faaany authored Feb 6, 2025
1 parent 4563ba2 commit 6246c03
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions docs/source/en/trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ from torch import nn
from transformers import Trainer

class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
labels = inputs.pop("labels")
# forward pass
outputs = model(**inputs)
Expand All @@ -156,9 +156,7 @@ class EarlyStoppingCallback(TrainerCallback):

def on_step_end(self, args, state, control, **kwargs):
if state.global_step >= self.num_steps:
return {"should_training_stop": True}
else:
return {}
control.should_training_stop = True
```

Then pass it to the [`Trainer`]'s `callback` parameter.
Expand Down Expand Up @@ -737,7 +735,7 @@ accelerate launch --num_processes=2 \
--fsdp_transformer_layer_cls_to_wrap="BertLayer" \
--fsdp_sharding_strategy=1 \
--fsdp_state_dict_type=FULL_STATE_DICT \
./examples/pytorch/text-classification/run_glue.py
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name $TASK_NAME \
--do_train \
Expand Down

0 comments on commit 6246c03

Please sign in to comment.