From d9337c370c259b3cf2ab7a7b449c05d02ecdc5b0 Mon Sep 17 00:00:00 2001
From: dafnapension
Date: Tue, 11 Feb 2025 21:51:20 +0200
Subject: [PATCH] surface a problem in current loadHF: it realizes too late that streaming should be False. For both cases: split = str, and split=None

Signed-off-by: dafnapension
---
 tests/library/test_loaders.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/tests/library/test_loaders.py b/tests/library/test_loaders.py
index 24245a783..8960d5269 100644
--- a/tests/library/test_loaders.py
+++ b/tests/library/test_loaders.py
@@ -4,6 +4,7 @@
 from unittest.mock import patch
 
 import pandas as pd
+from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
 from unitxt.error_utils import UnitxtError
 from unitxt.loaders import (
     LoadCSV,
@@ -232,6 +233,23 @@ def test_load_from_HF_multiple_innvocation_with_filter(self):
         )  # that HF dataset only has the 'test' split
         self.assertEqual(instance["language"], "eng")
 
+    def test_load_HF_lazily(self):
+        # streaming=True is requested, but load_dataset is expected to complete OK by
+        # switching to streaming=False; the type checks accept both kinds of dataset.
+        lazy_loader = LoadHF(path="ibm/finqa", streaming=True)
+        # A successful load is not enough: in current main the samples are only touched
+        # by the split generator when yielding, so read one sample explicitly here.
+        test_split = lazy_loader.load_dataset(split="test")
+        self.assertIsInstance(test_split, (Dataset, IterableDataset))
+        self.assertIsNotNone(next(iter(test_split)))
+
+        # The same goes when split=None; subTest names the split whose read failed.
+        all_splits = lazy_loader.load_dataset(split=None)
+        self.assertIsInstance(all_splits, (DatasetDict, IterableDatasetDict))
+        for split_name, split_dataset in all_splits.items():
+            with self.subTest(split=split_name):
+                self.assertIsNotNone(next(iter(split_dataset)))
+
     def test_load_from_HF_split(self):
         loader = LoadHF(path="sst2", split="train")
         ms = loader()