diff --git a/hestia/dataset_generator.py b/hestia/dataset_generator.py index c471539..8f90c8c 100644 --- a/hestia/dataset_generator.py +++ b/hestia/dataset_generator.py @@ -158,7 +158,7 @@ def __init__(self, data: pd.DataFrame): self.sim_df = None self.partitions = None print('Initialising Hestia Dataset Generator') - print(f'Number of items in data: {len(self.data)}') + print(f'Number of items in data: {len(self.data):,}') def from_precalculated(self, data_path: str): """Load partition indexes if they have already being calculated. @@ -260,10 +260,11 @@ def calculate_partitions( :type similarity_args: SimilarityArguments, optional :raises ValueError: Partitioning algorithm not supported. """ - print('Calculating partitions...') self.partitions = {} if self.sim_df is None: self.calculate_similarity(similarity_args) + print('Calculating partitions...') + if partition_algorithm == 'ccpart': partition_algorithm = ccpart elif partition_algorithm == 'graph_part': @@ -311,15 +312,23 @@ def generate_datasets(self, dataset_type: str, threshold: float) -> dict: if dataset_type == 'huggingface' or dataset_type == 'hf': try: import datasets + import pyarrow as pa except ImportError: raise ImportError( f"This dataset_type: {dataset_type} requires `datasets` " + "to be installed. Install using: `pip install datasets`" ) for key, value in self.partitions[threshold].items(): - ds[key] = datasets.Dataset.from_pandas( - self.data.iloc[value].reset_index() - ) + try: + ds[key] = datasets.Dataset.from_pandas( + self.data.iloc[value].reset_index() + ) + except pa.ArrowInvalid: + ds[key] = datasets.Dataset.from_dict({ + column: [row[column] for idx, row in + self.data.iloc[value].iterrows()] + for column in self.data.columns + }) return ds elif dataset_type == 'pytorch' or dataset_type == 'torch': try: diff --git a/hestia/utils/dataset_utils.py b/hestia/utils/dataset_utils.py index e15f255..d1fcace 100644 --- a/hestia/utils/dataset_utils.py +++ b/hestia/utils/dataset_utils.py @@ -13,4 +13,4 @@ def __getitem__(self, index): return features, label def __len__(self): - return len(self.dataframe) \ No newline at end of file + return len(self.dataframe) diff --git a/setup.py b/setup.py index cbe51f7..c70c94a 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,6 @@ test_suite='tests', tests_require=test_requirements, url='https://github.com/IBM/Hestia-OOD', - version='0.0.8', + version='0.0.9', zip_safe=False, )