Commit
Allow creation of datasets from huggingface. Run cleaning and sample constraints on constructed datasets.
harrykeightley committed Oct 16, 2023
1 parent a097692 commit 12e2b41
Showing 4 changed files with 134 additions and 12 deletions.
96 changes: 95 additions & 1 deletion elpis/datasets/processing.py
@@ -6,12 +6,20 @@
from loguru import logger
from transformers import AutoFeatureExtractor, AutoTokenizer

from elpis.datasets.clean_text import clean_text
from elpis.models.job import Job

LOGGING_TRANSCRIPT_SAMPLE = 2


def create_dataset(
def create_dataset(job: Job) -> DatasetDict:
if Path(job.data_args.dataset_name_or_path).is_dir():
return create_local_dataset(job)

return create_hf_dataset(job)


def create_local_dataset(
job: Job,
test_size: float = 0.2,
) -> DatasetDict:
@@ -65,6 +73,45 @@ def resolve_audio_path(row: Dict[str, Any]) -> Dict[str, Any]:
return dataset


def create_hf_dataset(job: Job) -> DatasetDict:
dataset = DatasetDict()
data_args = job.data_args

if job.training_args.do_train:
dataset["train"] = load_dataset(
data_args.dataset_name_or_path,
data_args.dataset_config_name,
split=data_args.train_split_name,
token=data_args.token,
)

if data_args.audio_column_name not in dataset["train"].column_names:
raise ValueError(
f"audio_column_name '{data_args.audio_column_name}' not found"
f" in dataset '{data_args.dataset_name_or_path}'."
" Make sure to set `audio_column_name` to the correct audio column - one of"
f" {', '.join(dataset['train'].column_names)}."
)

if data_args.text_column_name not in dataset["train"].column_names:
raise ValueError(
f"text_column_name {data_args.text_column_name} not found"
f" in dataset '{data_args.dataset_name_or_path}'. "
"Make sure to set `text_column_name` to the correct text column - one of "
f"{', '.join(dataset['train'].column_names)}."
)

if job.training_args.do_eval:
dataset["eval"] = load_dataset(
data_args.dataset_name_or_path,
data_args.dataset_config_name,
split=data_args.eval_split_name,
token=data_args.token,
)

return dataset


def prepare_dataset(
job: Job,
tokenizer: AutoTokenizer,
@@ -77,6 +124,8 @@ def prepare_dataset(
dataset: The dataset on which to apply the preprocessing
processor: The processor to apply over the dataset
"""
dataset = clean_dataset(job, dataset)
dataset = constrain_to_max_samples(job, dataset)

# Load the audio data and resample if necessary.
dataset = dataset.cast_column(
@@ -131,3 +180,48 @@ def is_audio_in_length_range(length: int):
logger.info(f"Test encoding labels: {dataset['train'][0]['labels']}")

return dataset


def constrain_to_max_samples(job: Job, dataset: DatasetDict) -> DatasetDict:
max_train_samples = job.data_args.max_train_samples
max_eval_samples = job.data_args.max_eval_samples

if job.training_args.do_train and max_train_samples is not None:
dataset["train"] = dataset["train"].select(range(max_train_samples))

if job.training_args.do_eval and max_eval_samples is not None:
dataset["eval"] = dataset["eval"].select(range(max_eval_samples))

return dataset


def clean_dataset(job: Job, dataset: DatasetDict) -> DatasetDict:
if not job.data_args.do_clean:
return dataset

text_column = job.data_args.text_column_name

def clean(batch: Dict[str, Any]):
characters_to_remove = "".join(job.data_args.chars_to_remove or [])
characters_to_explode = "".join(job.data_args.chars_to_explode or [])

batch[text_column] = (
clean_text(
batch[text_column],
words_to_remove=job.data_args.words_to_remove,
characters_to_remove=characters_to_remove,
characters_to_explode=characters_to_explode,
to_lower=job.data_args.do_lower_case or True,
)
+ " " # Note: not sure why this is necessary, but saw in hf docs.
)

return batch

with job.training_args.main_process_first(desc="Dataset cleaning."):
dataset = dataset.map(
clean,
desc="Cleaning the dataset and standardizing case.",
)

return dataset
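
For reference, a minimal usage sketch of the new dispatch: a local directory routes to create_local_dataset, while anything else is treated as a HuggingFace Hub dataset id. The DataArguments fields are the ones shown in the diff above; the ModelArguments field name, the model id and the Hub dataset id are placeholder assumptions, not taken from the commit.

from transformers import TrainingArguments

from elpis.datasets.processing import create_dataset
from elpis.models import DataArguments, Job, ModelArguments

job = Job(
    model_args=ModelArguments(model_name_or_path="facebook/wav2vec2-base"),  # assumed field name
    data_args=DataArguments(
        dataset_name_or_path="mozilla-foundation/common_voice_11_0",  # a Hub id, not a local dir
        dataset_config_name="en",
        train_split_name="train",
        eval_split_name="validation",
        audio_column_name="audio",
        text_column_name="sentence",
        do_clean=True,
        max_train_samples=1000,
        max_eval_samples=200,
    ),
    training_args=TrainingArguments(output_dir="out", do_train=True, do_eval=True),
)

# The path is not a local directory, so create_dataset falls through to create_hf_dataset.
dataset = create_dataset(job)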
12 changes: 11 additions & 1 deletion elpis/models/__init__.py
@@ -1,5 +1,15 @@
from elpis.models.annotation import Annotation
from elpis.models.elan_options import ElanOptions, ElanTierSelector
from elpis.models.job import DataArguments, Job, ModelArguments
from elpis.models.vocab import VOCAB_FILE, Vocab

__all__ = ["Annotation", "ElanOptions", "ElanTierSelector", "Vocab", "VOCAB_FILE"]
__all__ = [
"Annotation",
"ElanOptions",
"ElanTierSelector",
"Job",
"Vocab",
"VOCAB_FILE",
"DataArguments",
"ModelArguments",
]
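
Since the package root now re-exports the job types, downstream code can import them directly:

from elpis.models import DataArguments, Job, ModelArguments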
32 changes: 25 additions & 7 deletions elpis/models/job.py
@@ -186,11 +186,33 @@ class DataArguments:
)
},
)
chars_to_ignore: Optional[List[str]] = list_field(
do_clean: bool = field(
default=True,
metadata={"help": "True if the dataset should be cleaned before use."},
)
words_to_remove: Optional[List[str]] = list_field(
default=[],
metadata={
"help": "A list of words to remove from the transcripts during dataset cleaning."
},
)
chars_to_remove: Optional[List[str]] = list_field(
default=[],
metadata={
"help": "A list of characters to remove from the transcripts during dataset cleaning."
},
)
chars_to_explode: Optional[List[str]] = list_field(
default=[],
metadata={
"help": "A list of characters to replace with spaces in the transcripts during dataset cleaning."
},
)
do_lower_case: Optional[bool] = field(
default=None,
metadata={"help": "A list of characters to remove from the transcripts."},
metadata={"help": "Whether the target text should be lower cased."},
)
eval_metrics: List[str] = list_field(
eval_metrics: List[str] = list_field( # type: ignore
default=DEFAULT_METRICS,
metadata={
"help": "A list of metrics the model should be evaluated on. E.g. `('wer', 'cer')`"
@@ -270,10 +292,6 @@ class DataArguments:
)
},
)
do_lower_case: Optional[bool] = field(
default=None,
metadata={"help": "Whether the target text should be lower cased."},
)


@dataclass
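
The cleaning options added above are only described by their help strings in this diff; as a rough illustration of the intended behaviour, here is a stand-in sketch, not the actual elpis.datasets.clean_text implementation:

import re
from typing import List


def clean_transcript(
    text: str,
    words_to_remove: List[str],
    characters_to_remove: str,
    characters_to_explode: str,
    to_lower: bool = True,
) -> str:
    # Approximates the behaviour described by the DataArguments help strings.
    if to_lower:
        text = text.lower()
    for word in words_to_remove:
        # Drop whole words from the transcript.
        text = re.sub(rf"\b{re.escape(word)}\b", "", text)
    if characters_to_remove:
        # Delete the listed characters outright.
        text = text.translate(str.maketrans("", "", characters_to_remove))
    if characters_to_explode:
        # Replace the listed characters with spaces.
        text = text.translate(str.maketrans(characters_to_explode, " " * len(characters_to_explode)))
    return " ".join(text.split())


clean_transcript("Um, hello-world!", ["um"], ",!", "-")  # -> "hello world"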
6 changes: 3 additions & 3 deletions tests/datasets/test_processing.py
@@ -5,13 +5,13 @@
from loguru import logger
from transformers import TrainingArguments

from elpis.datasets.processing import create_dataset
from elpis.datasets.processing import create_local_dataset
from elpis.models.job import DataArguments, Job, ModelArguments

DATA_PATH = Path(__file__).parent.parent / "data" / "processing"


def test_create_dataset(tmp_path: Path):
def test_create_local_dataset(tmp_path: Path):
cache_dir = tmp_path / "cache"
dataset_dir = tmp_path / "dataset"
model_dir = tmp_path / "model"
@@ -39,6 +39,6 @@ def test_create_dataset(tmp_path: Path):
),
)

dataset = create_dataset(job)
dataset = create_local_dataset(job)
assert "train" in dataset
assert "eval" in dataset
