From 2321952cc97d7974c38852be2f8022a062a1d1e3 Mon Sep 17 00:00:00 2001 From: hariharandev1 Date: Fri, 30 Aug 2024 17:53:57 -0700 Subject: [PATCH] ensure the sampler do not goes past the file in the last rank. --- dlio_benchmark/data_loader/torch_data_loader.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dlio_benchmark/data_loader/torch_data_loader.py b/dlio_benchmark/data_loader/torch_data_loader.py index 989dae6a..80212fff 100644 --- a/dlio_benchmark/data_loader/torch_data_loader.py +++ b/dlio_benchmark/data_loader/torch_data_loader.py @@ -92,8 +92,10 @@ def __init__(self, rank, size, num_samples, epochs): self.epochs = epochs samples_per_proc = int(math.ceil(num_samples/size)) start_sample = self.rank * samples_per_proc - end_sample = (self.rank + 1) * samples_per_proc - self.indices = list(range(start_sample, end_sample)) + end_sample = (self.rank + 1) * samples_per_proc - 1 + if end_sample > num_samples - 1: + end_sample = num_samples - 1 + self.indices = list(range(start_sample, end_sample + 1)) def __len__(self):