Skip to content

Commit

Permalink
Merge pull request #1437 from lrzpellegrini/ffcv_support_pt2
Browse files Browse the repository at this point in the history
Fixed issues with 0 length samplers
  • Loading branch information
AntonioCarta authored Jun 28, 2023
2 parents 427d398 + 0c53d75 commit a61ae5c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 10 deletions.
19 changes: 15 additions & 4 deletions avalanche/benchmarks/utils/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# See the accompanying LICENSE file for terms. #
# #
# Date: 01-12-2020 #
# Author(s): Antonio Carta #
# Author(s): Antonio Carta, Lorenzo Pellegrini #
# E-mail: [email protected] #
# Website: avalanche.continualai.org #
################################################################################
Expand Down Expand Up @@ -33,9 +33,8 @@
from avalanche.benchmarks.utils.data_attribute import DataAttribute
from avalanche.distributed.distributed_helper import DistributedHelper

from torch.utils.data.sampler import BatchSampler
from torch.utils.data.sampler import Sampler, BatchSampler
from torch.utils.data import ConcatDataset
from torch.utils.data.sampler import Sampler


def return_identity(x):
Expand Down Expand Up @@ -571,6 +570,11 @@ def __init__(
self.termination_dataset_idx = -1
self.termination_dataset_iterations = 10 ** 10
self.oversample_small_datasets = True

if sum(len(x) for x in self.samplers) == 0:
raise RuntimeError(
'The never ending sampler must be able to create a mini-batch'
)
else:
# termination_dataset_idx => dataset used to determine the epoch end
self.termination_dataset_idx = termination_dataset_idx
Expand Down Expand Up @@ -621,6 +625,14 @@ def __iter__(self):
if sampler is None:
continue

if len(sampler) == 0:
if is_term_dataset and (not self.never_ending):
return

samplers_list[dataset_idx] = None
sampler_iterators[dataset_idx] = None
continue

should_stop_if_ended = (
is_term_dataset or
not self.oversample_small_datasets
Expand Down Expand Up @@ -672,7 +684,6 @@ def _next_batch(

# Re-create the iterator
# This time, do not catch StopIteration

if isinstance(sampler, BatchSampler):
if isinstance(sampler.sampler, DistributedSampler):
sampler.sampler.set_epoch(
Expand Down
55 changes: 49 additions & 6 deletions tests/benchmarks/test_replay_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,49 @@ def setUp(self):
dataset_for_current = scenario.train_stream[1].dataset
dataset_for_memory = scenario.train_stream[0].dataset

indices_big_set = np.random.choice(
np.arange(len(dataset_for_current)), size=10000, replace=False
)

indices_small_set = np.random.choice(
np.arange(len(dataset_for_current)), size=1000, replace=False
)

indices_big_set = np.random.choice(
np.arange(len(dataset_for_current)), size=10000, replace=False
indices_tiny_set = np.random.choice(
np.arange(len(dataset_for_current)), size=100, replace=False
)

self.big_task_set = \
AvalancheSubset(dataset_for_current, indices_big_set)
self.small_task_set = \
AvalancheSubset(dataset_for_current, indices_small_set)
self.tiny_task_set = \
AvalancheSubset(dataset_for_current, indices_tiny_set)

indices_memory = np.random.choice(
np.arange(len(dataset_for_memory)), size=2000, replace=False
)

indices_memory_small = np.random.choice(
np.arange(len(dataset_for_memory)), size=100, replace=False
)

self.memory_set = AvalancheSubset(dataset_for_memory, indices_memory)
self.small_memory_set = AvalancheSubset(
dataset_for_memory,
indices_memory_small
)

self._batch_size = None
self._task_dataset = None

def _make_loader(self, **kwargs):
def _make_loader(self, memory_set=None, **kwargs):
if memory_set is None:
memory_set = self.memory_set

loader = ReplayDataLoader(
self._task_dataset,
self.memory_set,
memory_set,
batch_size=self._batch_size,
batch_size_mem=self._batch_size,
oversample_small_tasks=True,
Expand All @@ -48,9 +66,12 @@ def _make_loader(self, **kwargs):
)
return loader

def _test_batch_size(self, loader):
def _test_batch_size(self, loader, expected_size=None):
    """Assert every mini-batch yielded by *loader* has *expected_size* samples.

    When *expected_size* is None, the default of ``self._batch_size * 2`` is
    used (the loader draws ``self._batch_size`` from each of the two streams).
    """
    expected = self._batch_size * 2 if expected_size is None else expected_size
    for minibatch in loader:
        self.assertEqual(len(minibatch[0]), expected)

def _test_length(self, loader):
    """Assert that ``len(loader)`` matches the expected ``self._length``."""
    observed = len(loader)
    self.assertEqual(observed, self._length)
Expand All @@ -66,6 +87,11 @@ def _launch_test_suite(self, loader):
self._test_length(loader)
self._test_actual_length(loader)

def _launch_test_suite_dropped_memory(self, loader):
    """Run the loader checks for the dropped-memory case.

    Here each batch is expected to contain only ``self._batch_size``
    samples (no replay-memory contribution), unlike the regular suite
    which expects twice that.
    """
    single_stream_size = self._batch_size
    self._test_batch_size(loader, expected_size=single_stream_size)
    self._test_length(loader)
    self._test_actual_length(loader)

def test_bigger_memory(self):
self._batch_size = 64
self._task_dataset = self.small_task_set
Expand All @@ -84,6 +110,23 @@ def test_big_batch_size(self):
loader = self._make_loader()
self._launch_test_suite(loader)

def test_zero_iterations_memory(self):
    """Big current task paired with a tiny memory set.

    The memory buffer is too small to contribute a full batch on every
    iteration, so batches are expected to hold only current-task samples
    (checked via the dropped-memory suite).
    """
    self._batch_size = 256
    self._task_dataset = self.big_task_set
    loader = self._make_loader(memory_set=self.small_memory_set)
    self._launch_test_suite_dropped_memory(loader)

def test_zero_iterations_current(self):
    """Tiny current task paired with the regular memory set.

    With a 256-sample batch and only a tiny current-task set, zero full
    iterations are possible — ``self._length`` is expected to be 0 and the
    standard suite must still pass without raising.
    """
    self._batch_size = 256
    self._task_dataset = self.tiny_task_set
    loader = self._make_loader(memory_set=self.memory_set)
    # presumably self._length is computed by a sibling helper from the
    # current batch size / dataset — verify against the full test file
    self.assertEqual(0, self._length)
    self._launch_test_suite(loader)

def test_small_batch_size(self):
self._batch_size = 5
self._task_dataset = self.big_task_set
Expand Down

0 comments on commit a61ae5c

Please sign in to comment.