Skip to content

Commit

Permalink
Merge pull request #6 from IBM/v.0.0.9
Browse files Browse the repository at this point in the history
V.0.0.9
  • Loading branch information
RaulFD-creator authored May 17, 2024
2 parents e687c1c + 7798cf4 commit 9190b17
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
19 changes: 14 additions & 5 deletions hestia/dataset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def __init__(self, data: pd.DataFrame):
self.sim_df = None
self.partitions = None
print('Initialising Hestia Dataset Generator')
print(f'Number of items in data: {len(self.data)}')
print(f'Number of items in data: {len(self.data):,}')

def from_precalculated(self, data_path: str):
"""Load partition indexes if they have already been calculated.
Expand Down Expand Up @@ -260,10 +260,11 @@ def calculate_partitions(
:type similarity_args: SimilarityArguments, optional
:raises ValueError: Partitioning algorithm not supported.
"""
print('Calculating partitions...')
self.partitions = {}
if self.sim_df is None:
self.calculate_similarity(similarity_args)
print('Calculating partitions...')

if partition_algorithm == 'ccpart':
partition_algorithm = ccpart
elif partition_algorithm == 'graph_part':
Expand Down Expand Up @@ -311,15 +312,23 @@ def generate_datasets(self, dataset_type: str, threshold: float) -> dict:
if dataset_type == 'huggingface' or dataset_type == 'hf':
try:
import datasets
import pyarrow as pa
except ImportError:
raise ImportError(
f"This dataset_type: {dataset_type} requires `datasets` " +
"to be installed. Install using: `pip install datasets`"
)
for key, value in self.partitions[threshold].items():
ds[key] = datasets.Dataset.from_pandas(
self.data.iloc[value].reset_index()
)
try:
ds[key] = datasets.Dataset.from_pandas(
self.data.iloc[value].reset_index()
)
except pa.ArrowInvalid:
ds[key] = datasets.Dataset.from_dict({
column: [row[column] for idx, row in
self.data.iloc[value].iterrows()]
for column in self.data.columns
})
return ds
elif dataset_type == 'pytorch' or dataset_type == 'torch':
try:
Expand Down
2 changes: 1 addition & 1 deletion hestia/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ def __getitem__(self, index):
return features, label

def __len__(self):
return len(self.dataframe)
return len(self.dataframe)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@
test_suite='tests',
tests_require=test_requirements,
url='https://github.com/IBM/Hestia-OOD',
version='0.0.8',
version='0.0.9',
zip_safe=False,
)

0 comments on commit 9190b17

Please sign in to comment.