From a441dbb8b0e8aa62532eebb2282455a87844bfc8 Mon Sep 17 00:00:00 2001 From: harrykeightley Date: Fri, 8 Sep 2023 04:43:53 +0000 Subject: [PATCH] deploy: cff43a06888f093623e69be35e23f8d2a9cde76a --- datasets/index.html | 26 +++++-- datasets/processing.html | 54 ++++++++++---- trainer/index.html | 22 +++++- trainer/job.html | 15 +++- trainer/metrics.html | 157 +++++++++++++++++++++++++++++++++++++++ trainer/trainer.html | 18 +++-- 6 files changed, 258 insertions(+), 34 deletions(-) create mode 100644 trainer/metrics.html diff --git a/datasets/index.html b/datasets/index.html index a85597f..a7de9c7 100644 --- a/datasets/index.html +++ b/datasets/index.html @@ -150,7 +150,12 @@

Parameters

processor: The processor to apply over the dataset """ - def prepare_dataset(batch: Dict) -> Dict[str, List]: + logger.debug(f"Dataset pre prep: {dataset}") + logger.debug(f"Dataset[train] pre prep: {dataset['train']['transcript']}") + logger.debug(f"Tokenizer vocab: {processor.tokenizer.vocab}") # type: ignore + + def _prepare_dataset(batch: Dict) -> Dict[str, List]: + # Also from https://huggingface.co/blog/fine-tune-xlsr-wav2vec2 audio = batch["audio"] batch["input_values"] = processor( @@ -158,16 +163,23 @@

Parameters

).input_values[0] batch["input_length"] = len(batch["input_values"]) - with processor.as_target_processor(): - batch["labels"] = processor(batch["transcript"]).input_ids + batch["labels"] = processor(text=batch["transcript"]).input_ids return batch - return dataset.map( - prepare_dataset, - remove_columns=dataset.column_names["train"], + column_names = [dataset.column_names[key] for key in dataset.column_names.keys()] + # flatten + columns_to_remove = list(chain.from_iterable(column_names)) + + dataset = dataset.map( + _prepare_dataset, + remove_columns=columns_to_remove, num_proc=PROCESSOR_COUNT, - ) + ) + + logger.debug(f"Dataset post prep: {dataset}") + logger.debug(f"Training labels: {dataset['train']['labels']}") + return dataset
diff --git a/datasets/processing.html b/datasets/processing.html index 2d97427..e0c08df 100644 --- a/datasets/processing.html +++ b/datasets/processing.html @@ -27,10 +27,12 @@

Module elpis.datasets.processing

Expand source code
import os
+from itertools import chain
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 from datasets import Audio, DatasetDict, load_dataset
+from loguru import logger
 from transformers import Wav2Vec2Processor
 
 PROCESSOR_COUNT = 4
@@ -88,7 +90,12 @@ 

Module elpis.datasets.processing

processor: The processor to apply over the dataset """ - def prepare_dataset(batch: Dict) -> Dict[str, List]: + logger.debug(f"Dataset pre prep: {dataset}") + logger.debug(f"Dataset[train] pre prep: {dataset['train']['transcript']}") + logger.debug(f"Tokenizer vocab: {processor.tokenizer.vocab}") # type: ignore + + def _prepare_dataset(batch: Dict) -> Dict[str, List]: + # Also from https://huggingface.co/blog/fine-tune-xlsr-wav2vec2 audio = batch["audio"] batch["input_values"] = processor( @@ -96,16 +103,23 @@

Module elpis.datasets.processing

).input_values[0] batch["input_length"] = len(batch["input_values"]) - with processor.as_target_processor(): - batch["labels"] = processor(batch["transcript"]).input_ids + batch["labels"] = processor(text=batch["transcript"]).input_ids return batch - return dataset.map( - prepare_dataset, - remove_columns=dataset.column_names["train"], + column_names = [dataset.column_names[key] for key in dataset.column_names.keys()] + # flatten + columns_to_remove = list(chain.from_iterable(column_names)) + + dataset = dataset.map( + _prepare_dataset, + remove_columns=columns_to_remove, num_proc=PROCESSOR_COUNT, - )
+ ) + + logger.debug(f"Dataset post prep: {dataset}") + logger.debug(f"Training labels: {dataset['train']['labels']}") + return dataset
@@ -195,7 +209,12 @@

Parameters

processor: The processor to apply over the dataset """ - def prepare_dataset(batch: Dict) -> Dict[str, List]: + logger.debug(f"Dataset pre prep: {dataset}") + logger.debug(f"Dataset[train] pre prep: {dataset['train']['transcript']}") + logger.debug(f"Tokenizer vocab: {processor.tokenizer.vocab}") # type: ignore + + def _prepare_dataset(batch: Dict) -> Dict[str, List]: + # Also from https://huggingface.co/blog/fine-tune-xlsr-wav2vec2 audio = batch["audio"] batch["input_values"] = processor( @@ -203,16 +222,23 @@

Parameters

).input_values[0] batch["input_length"] = len(batch["input_values"]) - with processor.as_target_processor(): - batch["labels"] = processor(batch["transcript"]).input_ids + batch["labels"] = processor(text=batch["transcript"]).input_ids return batch - return dataset.map( - prepare_dataset, - remove_columns=dataset.column_names["train"], + column_names = [dataset.column_names[key] for key in dataset.column_names.keys()] + # flatten + columns_to_remove = list(chain.from_iterable(column_names)) + + dataset = dataset.map( + _prepare_dataset, + remove_columns=columns_to_remove, num_proc=PROCESSOR_COUNT, - ) + ) + + logger.debug(f"Dataset post prep: {dataset}") + logger.debug(f"Training labels: {dataset['train']['labels']}") + return dataset diff --git a/trainer/index.html b/trainer/index.html index fef0731..dd319bc 100644 --- a/trainer/index.html +++ b/trainer/index.html @@ -43,6 +43,10 @@

Sub-modules

+
elpis.trainer.metrics
+
+
+
elpis.trainer.trainer
@@ -62,7 +66,7 @@

Functions

def train(job: TrainingJob, output_dir: pathlib.Path, dataset_dir: pathlib.Path, cache_dir: Optional[pathlib.Path] = None, log_file: Optional[pathlib.Path] = None) ‑> pathlib.Path
-

Trains a model for use in transcription.

+

Fine-tunes a model for use in transcription.

Parameters

job: Info about the training job, e.g. training options. output_dir: Where to save the trained model. @@ -82,7 +86,7 @@

Returns

cache_dir: Optional[Path] = None, log_file: Optional[Path] = None, ) -> Path: - """Trains a model for use in transcription. + """Fine-tunes a model for use in transcription. Parameters: job: Info about the training job, e.g. training options. @@ -125,6 +129,7 @@

Returns

eval_dataset=dataset["test"], # type: ignore tokenizer=processor.feature_extractor, data_collator=data_collator, + compute_metrics=create_metrics(job.metrics, processor), ) logger.info(f"Begin training model...") @@ -138,9 +143,9 @@

Returns

logger.info(f"Model written to disk.") metrics = trainer.evaluate() + logger.info("==== Metrics ====") trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) - logger.info("==== Metrics ====") logger.info(metrics) return output_dir @@ -153,7 +158,7 @@

Classes

class TrainingJob -(model_name: str, dataset_name: str, options: TrainingOptions, status: TrainingStatus = TrainingStatus.WAITING, base_model: str = 'facebook/wav2vec2-base-960h', sampling_rate: int = 16000) +(model_name: str, dataset_name: str, options: TrainingOptions, status: TrainingStatus = TrainingStatus.WAITING, base_model: str = 'facebook/wav2vec2-base-960h', sampling_rate: int = 16000, metrics: Tuple[str, ...] = ('wer', 'cer'))

A class representing a training job for a model

@@ -171,6 +176,7 @@

Classes

status: TrainingStatus = TrainingStatus.WAITING base_model: str = BASE_MODEL sampling_rate: int = SAMPLING_RATE + metrics: Tuple[str, ...] = METRICS def to_training_args(self, output_dir: Path, **kwargs) -> TrainingArguments: return TrainingArguments( @@ -205,6 +211,7 @@

Classes

status=TrainingStatus(data.get("status", TrainingStatus.WAITING)), base_model=data.get("base_model", BASE_MODEL), sampling_rate=data.get("sampling_rate", SAMPLING_RATE), + metrics=data.get("metrics", METRICS), ) def to_dict(self) -> Dict[str, Any]: @@ -222,6 +229,10 @@

Class variables

+
var metrics : Tuple[str, ...]
+
+
+
var model_name : str
@@ -259,6 +270,7 @@

Static methods

status=TrainingStatus(data.get("status", TrainingStatus.WAITING)), base_model=data.get("base_model", BASE_MODEL), sampling_rate=data.get("sampling_rate", SAMPLING_RATE), + metrics=data.get("metrics", METRICS), )
@@ -477,6 +489,7 @@

Index

@@ -494,6 +507,7 @@

base_model
  • dataset_name
  • from_dict
  • +
  • metrics
  • model_name
  • options
  • sampling_rate
  • diff --git a/trainer/job.html b/trainer/job.html index 9155b68..47c06d6 100644 --- a/trainer/job.html +++ b/trainer/job.html @@ -31,13 +31,14 @@

    Module elpis.trainer.job

    from dataclasses import dataclass, fields from enum import Enum from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Tuple import torch from transformers import TrainingArguments BASE_MODEL = "facebook/wav2vec2-base-960h" SAMPLING_RATE = 16_000 +METRICS = ("wer", "cer") class TrainingStatus(Enum): @@ -80,6 +81,7 @@

    Module elpis.trainer.job

    status: TrainingStatus = TrainingStatus.WAITING base_model: str = BASE_MODEL sampling_rate: int = SAMPLING_RATE + metrics: Tuple[str, ...] = METRICS def to_training_args(self, output_dir: Path, **kwargs) -> TrainingArguments: return TrainingArguments( @@ -114,6 +116,7 @@

    Module elpis.trainer.job

    status=TrainingStatus(data.get("status", TrainingStatus.WAITING)), base_model=data.get("base_model", BASE_MODEL), sampling_rate=data.get("sampling_rate", SAMPLING_RATE), + metrics=data.get("metrics", METRICS), ) def to_dict(self) -> Dict[str, Any]: @@ -133,7 +136,7 @@

    Classes

    class TrainingJob -(model_name: str, dataset_name: str, options: TrainingOptions, status: TrainingStatus = TrainingStatus.WAITING, base_model: str = 'facebook/wav2vec2-base-960h', sampling_rate: int = 16000) +(model_name: str, dataset_name: str, options: TrainingOptions, status: TrainingStatus = TrainingStatus.WAITING, base_model: str = 'facebook/wav2vec2-base-960h', sampling_rate: int = 16000, metrics: Tuple[str, ...] = ('wer', 'cer'))

    A class representing a training job for a model

    @@ -151,6 +154,7 @@

    Classes

    status: TrainingStatus = TrainingStatus.WAITING base_model: str = BASE_MODEL sampling_rate: int = SAMPLING_RATE + metrics: Tuple[str, ...] = METRICS def to_training_args(self, output_dir: Path, **kwargs) -> TrainingArguments: return TrainingArguments( @@ -185,6 +189,7 @@

    Classes

    status=TrainingStatus(data.get("status", TrainingStatus.WAITING)), base_model=data.get("base_model", BASE_MODEL), sampling_rate=data.get("sampling_rate", SAMPLING_RATE), + metrics=data.get("metrics", METRICS), ) def to_dict(self) -> Dict[str, Any]: @@ -202,6 +207,10 @@

    Class variables

    +
    var metrics : Tuple[str, ...]
    +
    +
    +
    var model_name : str
    @@ -239,6 +248,7 @@

    Static methods

    status=TrainingStatus(data.get("status", TrainingStatus.WAITING)), base_model=data.get("base_model", BASE_MODEL), sampling_rate=data.get("sampling_rate", SAMPLING_RATE), + metrics=data.get("metrics", METRICS), )
    @@ -461,6 +471,7 @@

    base_model
  • dataset_name
  • from_dict
  • +
  • metrics
  • model_name
  • options
  • sampling_rate
  • diff --git a/trainer/metrics.html b/trainer/metrics.html new file mode 100644 index 0000000..3df4ed7 --- /dev/null +++ b/trainer/metrics.html @@ -0,0 +1,157 @@ + + + + + + +elpis.trainer.metrics API documentation + + + + + + + + + + + +
    +
    +
    +

    Module elpis.trainer.metrics

    +
    +
    +
    + +Expand source code + +
    from typing import Callable, Dict, Optional, Sequence
    +
    +import evaluate
    +import numpy as np
    +from loguru import logger
    +from transformers import EvalPrediction, Wav2Vec2Processor
    +
    +
    +def create_metrics(
    +    metric_names: Sequence[str], processor: Wav2Vec2Processor
    +) -> Optional[Callable[[EvalPrediction], Dict]]:
    +    # Handle metrics
    +    if len(metric_names) == 0:
    +        return
    +
    +    # Note: was using evaluate.combine but was having many unexpected errors.
    +    metrics = {name: evaluate.load(name) for name in metric_names}
    +
    +    def compute_metrics(pred: EvalPrediction) -> Dict:
    +        # taken from https://huggingface.co/blog/fine-tune-xlsr-wav2vec2
    +        pred_logits = pred.predictions
    +
    +        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id  # type: ignore
    +
    +        # Taken from: https://discuss.huggingface.co/t/code-review-compute-metrics-for-wer-with-wav2vec2processorwithlm/16841/3
    +        if type(processor).__name__ == "Wav2Vec2ProcessorWithLM":
    +            pred_str = processor.batch_decode(pred_logits).text
    +        else:
    +            pred_ids = np.argmax(pred_logits, axis=-1)
    +            pred_str = processor.batch_decode(pred_ids)
    +
    +        # We do not want to group tokens when computing the metrics
    +        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    +
    +        logger.debug(f"METRICS->pred: {pred_str} label:{label_str}")
    +
    +        result = {
    +            name: metric.compute(predictions=pred_str, references=label_str)
    +            for name, metric in metrics.items()
    +        }
    +        logger.debug(f"Metrics Result: {result}")
    +        return result
    +
    +    return compute_metrics
    +
    +
    +
    +
    +
    +
    +
    +

    Functions

    +
    +
    +def create_metrics(metric_names: Sequence[str], processor: transformers.models.wav2vec2.processing_wav2vec2.Wav2Vec2Processor) ‑> Optional[Callable[[transformers.trainer_utils.EvalPrediction], Dict]] +
    +
    +
    +
    + +Expand source code + +
    def create_metrics(
    +    metric_names: Sequence[str], processor: Wav2Vec2Processor
    +) -> Optional[Callable[[EvalPrediction], Dict]]:
    +    # Handle metrics
    +    if len(metric_names) == 0:
    +        return
    +
    +    # Note: was using evaluate.combine but was having many unexpected errors.
    +    metrics = {name: evaluate.load(name) for name in metric_names}
    +
    +    def compute_metrics(pred: EvalPrediction) -> Dict:
    +        # taken from https://huggingface.co/blog/fine-tune-xlsr-wav2vec2
    +        pred_logits = pred.predictions
    +
    +        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id  # type: ignore
    +
    +        # Taken from: https://discuss.huggingface.co/t/code-review-compute-metrics-for-wer-with-wav2vec2processorwithlm/16841/3
    +        if type(processor).__name__ == "Wav2Vec2ProcessorWithLM":
    +            pred_str = processor.batch_decode(pred_logits).text
    +        else:
    +            pred_ids = np.argmax(pred_logits, axis=-1)
    +            pred_str = processor.batch_decode(pred_ids)
    +
    +        # We do not want to group tokens when computing the metrics
    +        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    +
    +        logger.debug(f"METRICS->pred: {pred_str} label:{label_str}")
    +
    +        result = {
    +            name: metric.compute(predictions=pred_str, references=label_str)
    +            for name, metric in metrics.items()
    +        }
    +        logger.debug(f"Metrics Result: {result}")
    +        return result
    +
    +    return compute_metrics
    +
    +
    +
    +
    +
    +
    +
    + +
    + + + \ No newline at end of file diff --git a/trainer/trainer.html b/trainer/trainer.html index 74211d3..b8add76 100644 --- a/trainer/trainer.html +++ b/trainer/trainer.html @@ -28,14 +28,16 @@

    Module elpis.trainer.trainer

    from contextlib import nullcontext
     from pathlib import Path
    -from typing import Optional
    +from typing import Dict, Optional
     
     from loguru import logger
    -from transformers import AutoModelForCTC, AutoProcessor, Trainer
    +from tokenizers import Tokenizer
    +from transformers import AutoModelForCTC, AutoProcessor, EvalPrediction, Trainer
     
     from elpis.datasets import create_dataset, prepare_dataset
     from elpis.trainer.data_collator import DataCollatorCTCWithPadding
     from elpis.trainer.job import TrainingJob
    +from elpis.trainer.metrics import create_metrics
     from elpis.trainer.utils import log_to_file
     
     
    @@ -46,7 +48,7 @@ 

    Module elpis.trainer.trainer

    cache_dir: Optional[Path] = None, log_file: Optional[Path] = None, ) -> Path: - """Trains a model for use in transcription. + """Fine-tunes a model for use in transcription. Parameters: job: Info about the training job, e.g. training options. @@ -89,6 +91,7 @@

    Module elpis.trainer.trainer

    eval_dataset=dataset["test"], # type: ignore tokenizer=processor.feature_extractor, data_collator=data_collator, + compute_metrics=create_metrics(job.metrics, processor), ) logger.info(f"Begin training model...") @@ -102,9 +105,9 @@

    Module elpis.trainer.trainer

    logger.info(f"Model written to disk.") metrics = trainer.evaluate() + logger.info("==== Metrics ====") trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) - logger.info("==== Metrics ====") logger.info(metrics) return output_dir
    @@ -121,7 +124,7 @@

    Functions

    def train(job: TrainingJob, output_dir: pathlib.Path, dataset_dir: pathlib.Path, cache_dir: Optional[pathlib.Path] = None, log_file: Optional[pathlib.Path] = None) ‑> pathlib.Path
    -

    Trains a model for use in transcription.

    +

    Fine-tunes a model for use in transcription.

    Parameters

    job: Info about the training job, e.g. training options. output_dir: Where to save the trained model. @@ -141,7 +144,7 @@

    Returns

    cache_dir: Optional[Path] = None, log_file: Optional[Path] = None, ) -> Path: - """Trains a model for use in transcription. + """Fine-tunes a model for use in transcription. Parameters: job: Info about the training job, e.g. training options. @@ -184,6 +187,7 @@

    Returns

    eval_dataset=dataset["test"], # type: ignore tokenizer=processor.feature_extractor, data_collator=data_collator, + compute_metrics=create_metrics(job.metrics, processor), ) logger.info(f"Begin training model...") @@ -197,9 +201,9 @@

    Returns

    logger.info(f"Model written to disk.") metrics = trainer.evaluate() + logger.info("==== Metrics ====") trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) - logger.info("==== Metrics ====") logger.info(metrics) return output_dir