From 59cca3c179e83badd391ffec436a15f6d8c62454 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Fri, 5 Apr 2024 13:55:39 +0200 Subject: [PATCH] fix: different shuffling at different epochs --- dmlcloud/util/data.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/dmlcloud/util/data.py b/dmlcloud/util/data.py index 4beffea..102d625 100644 --- a/dmlcloud/util/data.py +++ b/dmlcloud/util/data.py @@ -124,6 +124,10 @@ def __init__( self.rank = rank if rank is not None else dist.get_rank(process_group) self.world_size = world_size if world_size is not None else dist.get_world_size(process_group) + self._num_iters = 0 + + def set_epoch(self, epoch: int): + self._num_iters = epoch def __iter__(self): worker_info = get_worker_info() @@ -138,15 +142,15 @@ def __iter__(self): self.ds, self.dim, self.chunk_size, - self.chunk_overlap, - self.even_shards, - self.equal_chunks, - self.shuffle, - self.seed, - rank, - world_size, - self.load, - self.load_kwargs, + chunk_overlap=self.chunk_overlap, + even_shards=self.even_shards, + equal_chunks=self.equal_chunks, + shuffle=self.shuffle, + seed=self.seed + self._num_iters, + rank=rank, + world_size=world_size, + load=self.load, + load_kwargs=self.load_kwargs, )