diff --git a/avalanche/evaluation/metrics/detection.py b/avalanche/evaluation/metrics/detection.py index 2e01df060..513dc53d6 100644 --- a/avalanche/evaluation/metrics/detection.py +++ b/avalanche/evaluation/metrics/detection.py @@ -19,6 +19,7 @@ ) from avalanche.benchmarks.utils.data import AvalancheDataset +from avalanche.benchmarks.utils.data import _FlatDataWithTransform try: from lvis import LVIS @@ -470,12 +471,12 @@ def get_detection_api_from_dataset( recursion_result = get_detection_api_from_dataset( dataset.dataset, supported_types, none_if_not_found=True ) - elif isinstance(dataset, AvalancheDataset) and len(dataset._datasets) == 1: + elif isinstance(dataset, (AvalancheDataset, _FlatDataWithTransform)) and len(dataset._datasets) == 1: recursion_result = get_detection_api_from_dataset( dataset._datasets[0], supported_types, none_if_not_found=True ) - elif isinstance(dataset, (AvalancheDataset, ConcatDataset)): - if isinstance(dataset, AvalancheDataset): + elif isinstance(dataset, (AvalancheDataset, ConcatDataset, _FlatDataWithTransform)): + if isinstance(dataset, (AvalancheDataset, _FlatDataWithTransform)): datasets_list = dataset._datasets else: datasets_list = dataset.datasets