Skip to content

Commit

Permalink
Merge pull request #326 from RobotSail/track-total-samples
Browse files Browse the repository at this point in the history
feat: add total_samples as a field to logs being emitted
  • Loading branch information
JamesKunstle authored Nov 11, 2024
2 parents 45162d5 + 24591a7 commit ee14d11
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

# pylint: disable=no-name-in-module
from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, get_scheduler
import torch
Expand Down Expand Up @@ -325,7 +326,7 @@ def train(
lr_scheduler,
accelerator: Accelerator,
tokenizer,
train_loader,
train_loader: DataLoader,
grad_accum,
metric_logger,
):
Expand Down Expand Up @@ -457,6 +458,7 @@ def train(
"total_loss": float(log_loss / num_loss_counted_tokens),
"samples_seen": samples_seen,
"gradnorm": global_grad_norm,
"total_samples": len(train_loader.dataset),
# "weight_norm": weight_norm,
}
)
Expand Down Expand Up @@ -620,6 +622,7 @@ def main(args):
"num_batches": len(train_loader),
"avg_samples_per_batch": len(dataset) / len(train_loader),
"samples_per_gpu": args.samples_per_gpu,
"total_samples": len(dataset), # emit the total number of samples
}
)

Expand Down

0 comments on commit ee14d11

Please sign in to comment.