Commit

lint
AntonioCarta committed Jun 28, 2023
1 parent f097939 commit 2c8aba8
Showing 2 changed files with 9 additions and 17 deletions.
4 changes: 2 additions & 2 deletions avalanche/benchmarks/utils/data_loader.py
@@ -572,7 +572,7 @@ def __init__(
 
         if sum(len(x) for x in self.samplers) == 0:
             raise RuntimeError(
-                'The never ending sampler must be able to create a mini-batch'
+                "The never ending sampler must be able to create a mini-batch"
             )
         else:
             # termination_dataset_idx => dataset used to determine the epoch end
@@ -628,7 +628,7 @@ def __iter__(self):
                 if len(sampler) == 0:
                     if is_term_dataset and (not self.never_ending):
                         return
 
                     samplers_list[dataset_idx] = None
                     sampler_iterators[dataset_idx] = None
                     continue
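Note: the second hunk above touches the empty-sampler guard in __iter__. The snippet below is a minimal, self-contained sketch of that guard only, not the Avalanche implementation; the helper name draw_round and its signature are hypothetical. When a sampler has nothing left, the epoch either ends (if the sampler belongs to the termination dataset and the loader is not never-ending) or that sampler is dropped and drawing continues from the remaining datasets.

# Illustrative sketch only (not the Avalanche code) of the guard shown above.
from typing import List, Optional, Sequence

def draw_round(
    samplers: Sequence[Sequence[int]],
    termination_dataset_idx: int,
    never_ending: bool,
) -> Optional[List[int]]:
    """Return one index per still-active sampler, or None when the epoch ends."""
    samplers_list: List[Optional[Sequence[int]]] = list(samplers)
    picked: List[int] = []
    for dataset_idx, sampler in enumerate(samplers_list):
        if sampler is None:
            continue
        is_term_dataset = dataset_idx == termination_dataset_idx
        if len(sampler) == 0:
            if is_term_dataset and (not never_ending):
                return None  # termination dataset exhausted: the epoch ends
            samplers_list[dataset_idx] = None  # drop the empty sampler
            continue
        picked.append(sampler[0])  # stand-in for "draw the next indices"
    return picked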
22 changes: 7 additions & 15 deletions tests/benchmarks/test_replay_loader.py
@@ -27,12 +27,9 @@ def setUp(self):
             np.arange(len(dataset_for_current)), size=100, replace=False
         )
 
-        self.big_task_set = \
-            AvalancheSubset(dataset_for_current, indices_big_set)
-        self.small_task_set = \
-            AvalancheSubset(dataset_for_current, indices_small_set)
-        self.tiny_task_set = \
-            AvalancheSubset(dataset_for_current, indices_tiny_set)
+        self.big_task_set = AvalancheSubset(dataset_for_current, indices_big_set)
+        self.small_task_set = AvalancheSubset(dataset_for_current, indices_small_set)
+        self.tiny_task_set = AvalancheSubset(dataset_for_current, indices_tiny_set)
 
         indices_memory = np.random.choice(
             np.arange(len(dataset_for_memory)), size=2000, replace=False
@@ -44,8 +41,7 @@ def setUp(self):
 
         self.memory_set = AvalancheSubset(dataset_for_memory, indices_memory)
         self.small_memory_set = AvalancheSubset(
-            dataset_for_memory,
-            indices_memory_small
+            dataset_for_memory, indices_memory_small
         )
 
         self._batch_size = None
@@ -69,7 +65,7 @@ def _make_loader(self, memory_set=None, **kwargs):
     def _test_batch_size(self, loader, expected_size=None):
         if expected_size is None:
             expected_size = self._batch_size * 2
 
         for batch in loader:
             self.assertEqual(len(batch[0]), expected_size)
 
@@ -113,17 +109,13 @@ def test_big_batch_size(self):
     def test_zero_iterations_memory(self):
         self._batch_size = 256
         self._task_dataset = self.big_task_set
-        loader = self._make_loader(
-            memory_set=self.small_memory_set
-        )
+        loader = self._make_loader(memory_set=self.small_memory_set)
         self._launch_test_suite_dropped_memory(loader)
 
     def test_zero_iterations_current(self):
         self._batch_size = 256
         self._task_dataset = self.tiny_task_set
-        loader = self._make_loader(
-            memory_set=self.memory_set
-        )
+        loader = self._make_loader(memory_set=self.memory_set)
         self.assertEqual(0, self._length)
         self._launch_test_suite(loader)
 
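Note: the tests above assume the replay loader concatenates one current-task batch with one memory batch per step, which is why _test_batch_size expects len(batch[0]) == self._batch_size * 2 by default. The sketch below illustrates that batching behaviour under the same assumption; combined_batches and the toy TensorDatasets are illustrative stand-ins, not the Avalanche ReplayDataLoader API.

# A toy sketch (not Avalanche's ReplayDataLoader) of the batching behaviour the
# tests exercise: each step concatenates a current-task batch with a memory
# batch, so the combined batch size is batch_size * 2.
import torch
from torch.utils.data import DataLoader, TensorDataset

def combined_batches(current_set, memory_set, batch_size):
    current_loader = DataLoader(current_set, batch_size=batch_size,
                                shuffle=True, drop_last=True)
    memory_loader = DataLoader(memory_set, batch_size=batch_size,
                               shuffle=True, drop_last=True)
    memory_iter = iter(memory_loader)
    for cur_x, cur_y in current_loader:
        try:
            mem_x, mem_y = next(memory_iter)
        except StopIteration:
            memory_iter = iter(memory_loader)  # restart the memory stream
            mem_x, mem_y = next(memory_iter)
        yield torch.cat([cur_x, mem_x]), torch.cat([cur_y, mem_y])

# Usage: mirrors the batch-size check in _test_batch_size above.
current = TensorDataset(torch.randn(500, 8), torch.zeros(500, dtype=torch.long))
memory = TensorDataset(torch.randn(2000, 8), torch.ones(2000, dtype=torch.long))
for x, y in combined_batches(current, memory, batch_size=32):
    assert len(x) == 32 * 2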

0 comments on commit 2c8aba8
