Skip to content

Commit

Permalink
Fix issue ContinualAI#774
Browse files Browse the repository at this point in the history
  • Loading branch information
lrzpellegrini committed Jul 19, 2023
1 parent 435b40d commit 960d013
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
14 changes: 14 additions & 0 deletions avalanche/benchmarks/classic/core50.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,20 @@ def CORe50(
eval_transform=eval_transform,
)

if scenario == "nc":
n_classes_per_exp = []
classes_order = []
for exp in benchmark_obj.train_stream:
exp_dataset = exp.dataset
unique_targets = list(
sorted(set(int(x) for x in exp_dataset.targets)) # type: ignore
)
n_classes_per_exp.append(len(unique_targets))
classes_order.extend(unique_targets)
setattr(benchmark_obj, "n_classes_per_exp", n_classes_per_exp)
setattr(benchmark_obj, "classes_order", classes_order)
setattr(benchmark_obj, "n_classes", 50 if object_lvl else 10)

return benchmark_obj


Expand Down
5 changes: 5 additions & 0 deletions tests/test_core50.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def test_core50_nc_benchmark(self):
classes_in_test = benchmark_instance.classes_in_experience["test"][0]
self.assertSetEqual(set(range(50)), set(classes_in_test))

# Regression tests for issue #774
self.assertSequenceEqual([10] + ([5] * 8), benchmark_instance.n_classes_per_exp)
self.assertSetEqual(set(range(50)), set(benchmark_instance.classes_order))
self.assertEqual(50, len(benchmark_instance.classes_order))


if __name__ == "__main__":
unittest.main()

0 comments on commit 960d013

Please sign in to comment.