Skip to content

Commit

Permalink
refresh loaders from just_lazy_loader
Browse files Browse the repository at this point in the history
Signed-off-by: dafnapension <[email protected]>
  • Loading branch information
dafnapension committed Feb 9, 2025
1 parent 3e990f3 commit 5fc03dd
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 29 deletions.
53 changes: 26 additions & 27 deletions src/unitxt/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
from tqdm import tqdm

from .dataclass import OptionalField
from .error_utils import UnitxtError
from .error_utils import UnitxtError, UnitxtWarning
from .fusion import FixedFusion
from .generator_utils import ReusableGenerator
from .logging_utils import get_logger
Expand Down Expand Up @@ -226,7 +226,7 @@ class LoadHF(Loader):
Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
] = None
revision: Optional[str] = None
streaming: bool = None
streaming = None
filtering_lambda: Optional[str] = None
num_proc: Optional[int] = None
requirements_list: List[str] = OptionalField(default_factory=list)
Expand Down Expand Up @@ -313,27 +313,25 @@ def load_dataset(
next(iter(dataset[k]))
break

except:
try:
current_streaming = kwargs["streaming"]
logger.info(
f"needed to swap streaming from {current_streaming} to {not current_streaming} for path {self.path}"
)
# try the opposite way of streaming
kwargs["streaming"] = not kwargs["streaming"]
dataset = hf_load_dataset(**kwargs)
if isinstance(dataset, (Dataset, IterableDataset)):
next(iter(dataset))
else:
for k in dataset.keys():
next(iter(dataset[k]))
break

except ValueError as e:
if "trust_remote_code" in str(e):
raise ValueError(
f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
) from e
except Exception as e:
if e is ValueError and "trust_remote_code" in str(e):
raise ValueError(
f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
) from e

current_streaming = kwargs["streaming"]
logger.info(
f"needed to swap streaming from {current_streaming} to {not current_streaming} for path {self.path}"
)
# try the opposite way of streaming
kwargs["streaming"] = not kwargs["streaming"]
dataset = hf_load_dataset(**kwargs)
if isinstance(dataset, (Dataset, IterableDataset)):
next(iter(dataset))
else:
for k in dataset.keys():
next(iter(dataset[k]))
break

if self.filtering_lambda is not None:
dataset = dataset.filter(eval(self.filtering_lambda))
Expand Down Expand Up @@ -372,6 +370,9 @@ def get_splits(self) -> List[str]:
# split names are known before the split themselves are pulled from HF,
# and we can postpone that pulling of the splits until actually demanded
return list(dataset_info.splits.keys())
UnitxtWarning(
f'LoadHF(path="{self.path}", name="{self.name}") could not retrieve split names without loading the dataset. Consider defining "splits" in the LoadHF definition to improve loading time.'
)
return None
except:
return None
Expand Down Expand Up @@ -914,9 +915,9 @@ class LoadFromHFSpace(LoadHF):
)
"""

path = None
space_name: str
data_files: Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
path: Optional[str] = None
revision: Optional[str] = None
use_token: Optional[bool] = None
token_env: Optional[str] = None
Expand Down Expand Up @@ -1054,8 +1055,6 @@ def _maybe_set_classification_policy(self):
def load_data(self):
self._map_wildcard_path_to_full_paths()
self.path = self._download_data()
if self.splits is None and isinstance(self.data_files, dict):
self.splits = sorted(self.data_files.keys())

return super().load_data()

Expand Down Expand Up @@ -1090,7 +1089,7 @@ class LoadFromAPI(Loader):

urls: Dict[str, str]
chunksize: int = 100000
streaming: bool = False
streaming = False
api_key_env_var: str = "SQL_API_KEY"
headers: Optional[Dict[str, Any]] = None
data_field: str = "data"
Expand Down
4 changes: 2 additions & 2 deletions utils/.secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
"filename": "src/unitxt/loaders.py",
"hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
"is_verified": false,
"line_number": 629,
"line_number": 630,
"is_secret": false
}
],
Expand Down Expand Up @@ -184,5 +184,5 @@
}
]
},
"generated_at": "2025-02-08T13:56:45Z"
"generated_at": "2025-02-09T12:07:07Z"
}

0 comments on commit 5fc03dd

Please sign in to comment.