Hf data preprocessing: Truncate dataset if max_samples is provided (#1561)

## Describe your changes
- For Hugging Face data preprocessing (except text-gen, which has its own
logic), truncate the dataset before tokenization if `max_samples` is
provided.
- There is no need to tokenize and process the whole dataset if only a
subset is going to be used. This is useful for large datasets where the
tokenized data might be too large to fit in memory; see the sketch below.
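For illustration, here is a minimal standalone sketch of the same pattern using the `datasets` and `transformers` libraries directly. The checkpoint name and dummy data are hypothetical stand-ins, not taken from Olive's code:

```python
# Minimal sketch of truncate-before-tokenize (the checkpoint and the dummy
# "text" column are illustrative assumptions, not Olive code).
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = Dataset.from_dict({"text": [f"sample {i}" for i in range(10_000)]})

max_samples = 100  # analogous to Olive's `max_samples` option
if max_samples is not None:
    # cheap row selection; avoids tokenizing all 10,000 examples
    dataset = dataset.select(range(min(len(dataset), max_samples)))

# map() now tokenizes only the 100 selected rows
tokenized = dataset.map(lambda batch: tokenizer(batch["text"]), batched=True)
print(len(tokenized))  # -> 100
```

Because `select` only records row indices, the subsequent `map` call touches just the selected rows instead of the full dataset.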

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
jambayk authored Jan 22, 2025
1 parent be1278c commit cdb8693
Showing 1 changed file with 9 additions and 4 deletions: olive/data/component/pre_process_data.py
```diff
@@ -30,18 +30,23 @@ def pre_process(dataset, **kwargs):
     return dataset
 
 
-def _huggingface_pre_process_helper(dataset, map_func, **kwargs):
+def _huggingface_pre_process_helper(dataset, map_func, max_samples, **kwargs):
     """Apply a map function to the dataset.
 
     Args:
         dataset (object): Data to be pre-processed.
         map_func (function): Function to be applied to the dataset.
+        max_samples (int): Max number of samples to use.
         **kwargs: Additional arguments.
 
     Returns:
         object: Pre-processed data.
 
     """
+    if max_samples is not None:
+        # select the data beforehand to avoid tokenizing the whole dataset
+        dataset = dataset.select(range(min(len(dataset), max_samples)))
+
     # output type is list
     tokenized_datasets = dataset.map(
         map_func,
@@ -96,7 +101,7 @@ def _tokenizer_and_align_labels(examples):
     if model_hf_config and model_hf_config.label2id:
         dataset = dataset.align_labels_with_mapping(model_hf_config.label2id, label_col)
 
-    tokenized_datasets = _huggingface_pre_process_helper(dataset, _tokenizer_and_align_labels, **kwargs)
+    tokenized_datasets = _huggingface_pre_process_helper(dataset, _tokenizer_and_align_labels, max_samples, **kwargs)
     # label_col is "label" since we added label_col as "label" to tokenized_inputs
     return BaseDataset(tokenized_datasets, label_col="label", max_samples=max_samples)
 
@@ -146,7 +151,7 @@ def _tokenizer_and_align_labels(examples):
         tokenized_inputs["label"] = new_labels
         return tokenized_inputs
 
-    tokenized_datasets = _huggingface_pre_process_helper(dataset, _tokenizer_and_align_labels, **kwargs)
+    tokenized_datasets = _huggingface_pre_process_helper(dataset, _tokenizer_and_align_labels, max_samples, **kwargs)
     return BaseDataset(tokenized_datasets, label_col="label", max_samples=max_samples)
 
 
@@ -238,5 +243,5 @@ def _tokenizer_and_align_labels(examples):
 
         return tokenized_inputs
 
-    tokenized_datasets = _huggingface_pre_process_helper(dataset, _tokenizer_and_align_labels, **kwargs)
+    tokenized_datasets = _huggingface_pre_process_helper(dataset, _tokenizer_and_align_labels, max_samples, **kwargs)
     return BaseDataset(tokenized_datasets, label_col="label", max_samples=max_samples)
```
