Skip to content

Commit

Permalink
feat: more comprehensive dataset logs
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Oct 10, 2023
1 parent 73a7b87 commit af5ac42
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions dmlcloud/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,20 +233,36 @@ def setup_dataset(self):
logging.info(f'Dataset creation took {(datetime.now() - ts).total_seconds():.1f}s')

if hasattr(self.train_dl, 'dataset') and hasattr(self.train_dl.dataset, '__len__'):
logging.info(f'Train dataset size: {len(self.train_dl.dataset)}')
if hasattr(self.val_dl, 'dataset') and hasattr(self.val_dl.dataset, '__len__'):
logging.info(f' Val dataset size: {len(self.val_dl.dataset)}')

train_samples = f'{len(self.train_dl.dataset)}'
else:
train_samples = 'N/A'
train_sizes = hvd.allgather(torch.tensor([len(self.train_dl)]), name='train_dataset_size')
train_sizes = [t.item() for t in train_sizes]
msg = 'Train dataset:'
msg += f'\n\t* Batches: {train_sizes[0]}'
msg += f'\n\t* Batches (total): {sum(train_sizes)}'
msg += f'\n\t* Samples (calculated): {sum(train_sizes) * self.cfg.batch_size}'
msg += f'\n\t* Samples (raw): {train_samples}'
logging.info(msg)
if len(set(train_sizes)) > 1 and self.is_root:
logging.warning(f'Uneven train dataset batches: {train_sizes}')
logging.warning(f'!!! Uneven train dataset batches: {train_sizes}')

if self.val_dl is not None:
if hasattr(self.val_dl, 'dataset') and hasattr(self.val_dl.dataset, '__len__'):
val_samples = f'{len(self.val_dl.dataset)}'
else:
val_samples = 'N/A'

val_sizes = hvd.allgather(torch.tensor([len(self.val_dl)]), name='val_dataset_size')
val_sizes = [t.item() for t in val_sizes]
msg = 'Train dataset:'
msg += f'\n\t* Batches: {val_sizes[0]}'
msg += f'\n\t* Batches (total): {sum(val_sizes)}'
msg += f'\n\t* Samples (calculated): {sum(val_sizes) * self.cfg.batch_size}'
msg += f'\n\t* Samples (raw): {val_samples}'
logging.info(msg)
if len(set(val_sizes)) > 1 and self.is_root:
logging.warning(f'Uneven val dataset batches: {val_sizes}')
logging.warning(f'!!! Uneven val dataset batches: {val_sizes}')

log_delimiter()

Expand Down

0 comments on commit af5ac42

Please sign in to comment.