From 24e666976a34a8461565e2981b80c456b76864ad Mon Sep 17 00:00:00 2001 From: Calder Johnson Date: Sun, 10 Sep 2023 20:01:15 -0400 Subject: [PATCH 1/4] Initial commit (untested/broken) --- pyproject.toml | 2 + stopes/pipelines/asr_bleu/__init__.py | 0 stopes/pipelines/asr_bleu/asr_bleu.py | 52 ++++++++++ stopes/pipelines/asr_bleu/compute_asr_bleu.py | 96 +++++++++++++++++++ stopes/pipelines/asr_bleu/conf/asr_bleu.yaml | 17 ++++ .../conf/launcher/cache/file_cache.yaml | 2 + .../asr_bleu/conf/launcher/local.yaml | 8 ++ .../asr_bleu/conf/launcher/submitit.yaml | 8 ++ stopes/pipelines/asr_bleu/configs.py | 25 +++++ 9 files changed, 210 insertions(+) create mode 100644 stopes/pipelines/asr_bleu/__init__.py create mode 100644 stopes/pipelines/asr_bleu/asr_bleu.py create mode 100644 stopes/pipelines/asr_bleu/compute_asr_bleu.py create mode 100644 stopes/pipelines/asr_bleu/conf/asr_bleu.yaml create mode 100644 stopes/pipelines/asr_bleu/conf/launcher/cache/file_cache.yaml create mode 100644 stopes/pipelines/asr_bleu/conf/launcher/local.yaml create mode 100644 stopes/pipelines/asr_bleu/conf/launcher/submitit.yaml create mode 100644 stopes/pipelines/asr_bleu/configs.py diff --git a/pyproject.toml b/pyproject.toml index 1a0e93f..23ec357 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,8 +55,10 @@ classifiers=[ "faiss-gpu", ] speech = [ + "fairseq2", "hanziconv", "inflect", + "seamless_communication", "tnkeeh", "torchaudio", "num2words", diff --git a/stopes/pipelines/asr_bleu/__init__.py b/stopes/pipelines/asr_bleu/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stopes/pipelines/asr_bleu/asr_bleu.py b/stopes/pipelines/asr_bleu/asr_bleu.py new file mode 100644 index 0000000..a28b4ce --- /dev/null +++ b/stopes/pipelines/asr_bleu/asr_bleu.py @@ -0,0 +1,52 @@ +import asyncio +import logging +from pathlib import Path + +import hydra +from omegaconf import OmegaConf + +from stopes.core import utils +from stopes.pipelines.asr_bleu.compute_asr_bleu import compute_asr_bleu +from stopes.pipelines.asr_bleu.configs import AsrBleuConfig + +logger = logging.getLogger("asr_bleu") + + +class AsrBleu: + def __init__(self, config: AsrBleuConfig): + self.config = config + self.ensure_all_dirs() + self.launcher = hydra.utils.instantiate(self.config.launcher) + self.config.launcher.cache.caching_dir = Path(self.output_dir) / "cache" + OmegaConf.save( + config=config, + f=str(self.output_dir / "asr_bleu.yaml"), + ) + + OmegaConf.set_readonly(self.config, True) + + async def run(self): + logger.info("Computing ASRBleu on selected datasets...") + await compute_asr_bleu( + self.config.output_dir, + self.config.split, + self.config.model_name, + self.config.eval_first_pass, + self.config.dataset_name, + self.config.datasets, + self.launcher, + ) + + def ensure_all_dirs(self) -> None: + self.output_dir = Path(self.config.output_dir).resolve() + utils.ensure_dir(self.output_dir) + + +@hydra.main(config_path="conf", config_name="asr_bleu") +def main(config: AsrBleuConfig) -> None: + pipeline = AsrBleu(config) + asyncio.run(pipeline.run()) + + +if __name__ == "__main__": + main() diff --git a/stopes/pipelines/asr_bleu/compute_asr_bleu.py b/stopes/pipelines/asr_bleu/compute_asr_bleu.py new file mode 100644 index 0000000..ec26901 --- /dev/null +++ b/stopes/pipelines/asr_bleu/compute_asr_bleu.py @@ -0,0 +1,96 @@ +import logging +import typing as tp +from dataclasses import dataclass + +from m4t_scripts.evaluate.asr_bleu import ASRBleu +from omegaconf.omegaconf import MISSING + +from stopes.core.launcher import Launcher +from stopes.core.stopes_module import Requirements, StopesModule +from stopes.pipelines.asr_bleu.configs import DatasetsConfig + + +@dataclass +class ComputeASRBleuJob: + lang_dir: str = MISSING + split: str = MISSING + num_data_pairs: int = MISSING + model_name: str = MISSING + eval_first_pass: bool = MISSING + dataset: str = MISSING + audio_format: str = MISSING + + +@dataclass +class ComputeASRBleuConfig: + compute_asrbleu_jobs: tp.List[ComputeASRBleuJob] = MISSING + output_dir: str = MISSING + + +class ComputeASRBleu(StopesModule): + def __init__(self, config: ComputeASRBleuConfig): + super().__init__(config=config, config_class=ComputeASRBleuConfig) + self.asrbleu = ASRBleu(config.output_dir) + + def array(self): + return self.config.compute_asrbleu_jobs + + def requirements(self) -> Requirements: + return Requirements( + nodes=1, + tasks_per_node=1, + gpus_per_node=0, + cpus_per_task=1, + timeout_min=24 * 60, + ) + + def run( + self, + iteration_value: tp.Optional[tp.Any] = None, + iteration_index: int = 0, + ): + """Runs compute_asr_bleu for each ComputeASRBleuJob""" + assert iteration_value is not None, "iteration value is null" + self.logger = logging.getLogger("stopes.asr_bleu") + self.logger.info(f"Running compute_asr_bleu on {iteration_value.lang_dir}") + self.asrbleu.compute_asr_bleu( + iteration_value.lang_dir, + iteration_value.split, + iteration_value.num_data_pairs, + iteration_value.model_name, + iteration_value.eval_first_pass, + iteration_value.dataset, + iteration_value.audio_format, + ) + + +async def compute_asr_bleu( + output_dir: str, + split: str, + model_name: str, + eval_first_pass: bool, + dataset_name: str, + datasets_conf: DatasetsConfig, + launcher: Launcher, +) -> tp.List[tp.Tuple[tp.Dict[str, tp.List], str, str]]: + """ + Compute ASRBleu on specified datasets + """ + datasets = datasets_conf.datasets + compute_asrbleu_jobs = [ + ComputeASRBleuJob( + lang_dir=datasets[dataset].lang_dir, + split=split, + num_data_pairs=datasets[dataset].num_data_pairs, + model_name=model_name, + eval_first_pass=eval_first_pass, + dataset=dataset_name, + ) + for dataset in datasets + ] + compute_asrbleu_module = ComputeASRBleu( + ComputeASRBleuConfig( + compute_asrbleu_jobs=compute_asrbleu_jobs, output_dir=output_dir + ) + ) + await launcher.schedule(compute_asrbleu_module) diff --git a/stopes/pipelines/asr_bleu/conf/asr_bleu.yaml b/stopes/pipelines/asr_bleu/conf/asr_bleu.yaml new file mode 100644 index 0000000..e9fe6c3 --- /dev/null +++ b/stopes/pipelines/asr_bleu/conf/asr_bleu.yaml @@ -0,0 +1,17 @@ +defaults: + - launcher: local + - _self_ + +output_dir: ??? +split: "test" +model_name: "SeamlessM4T_medium" +eval_first_pass: True +dataset: "fleurs" +audio_format: "n_pred.wav" + +launcher: + partition: ??? # set as null if running locally + cache: + caching_dir: ${output_dir}/cache + +datasets: ??? \ No newline at end of file diff --git a/stopes/pipelines/asr_bleu/conf/launcher/cache/file_cache.yaml b/stopes/pipelines/asr_bleu/conf/launcher/cache/file_cache.yaml new file mode 100644 index 0000000..264f635 --- /dev/null +++ b/stopes/pipelines/asr_bleu/conf/launcher/cache/file_cache.yaml @@ -0,0 +1,2 @@ +_target_: stopes.core.FileCache +caching_dir: /tmp/stopes_cache diff --git a/stopes/pipelines/asr_bleu/conf/launcher/local.yaml b/stopes/pipelines/asr_bleu/conf/launcher/local.yaml new file mode 100644 index 0000000..d10ebdd --- /dev/null +++ b/stopes/pipelines/asr_bleu/conf/launcher/local.yaml @@ -0,0 +1,8 @@ +defaults: + - cache: file_cache + +_target_: stopes.core.Launcher +log_folder: executor_logs +cluster: local +partition: null +max_jobarray_jobs: 1000 \ No newline at end of file diff --git a/stopes/pipelines/asr_bleu/conf/launcher/submitit.yaml b/stopes/pipelines/asr_bleu/conf/launcher/submitit.yaml new file mode 100644 index 0000000..0e1b0a2 --- /dev/null +++ b/stopes/pipelines/asr_bleu/conf/launcher/submitit.yaml @@ -0,0 +1,8 @@ +defaults: + - cache: file_cache + +_target_: stopes.core.Launcher +log_folder: executor_logs +cluster: slurm +partition: null +max_jobarray_jobs: 1000 \ No newline at end of file diff --git a/stopes/pipelines/asr_bleu/configs.py b/stopes/pipelines/asr_bleu/configs.py new file mode 100644 index 0000000..50663c9 --- /dev/null +++ b/stopes/pipelines/asr_bleu/configs.py @@ -0,0 +1,25 @@ +import typing as tp +from dataclasses import dataclass + + +@dataclass +class Dataset: + lang_dir: str + num_data_pairs: int + + +@dataclass +class DatasetsConfig: + datasets: tp.Dict[str, Dataset] + + +@dataclass +class AsrBleuConfig: + output_dir: str + split: str + model_name: str + eval_first_pass: bool + dataset: str + audio_format: str + launcher: tp.Dict[str, tp.Any] + datasets: DatasetsConfig From 703727c7d83bde41fac3c39e123d3b1b320a03ef Mon Sep 17 00:00:00 2001 From: Calder Johnson Date: Mon, 11 Sep 2023 16:22:09 -0400 Subject: [PATCH 2/4] Module is functional --- stopes/pipelines/asr_bleu/asr_bleu.py | 3 ++- stopes/pipelines/asr_bleu/compute_asr_bleu.py | 13 +++++++------ stopes/pipelines/asr_bleu/conf/asr_bleu.yaml | 2 +- stopes/pipelines/asr_bleu/configs.py | 7 +------ 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/stopes/pipelines/asr_bleu/asr_bleu.py b/stopes/pipelines/asr_bleu/asr_bleu.py index a28b4ce..0a587ff 100644 --- a/stopes/pipelines/asr_bleu/asr_bleu.py +++ b/stopes/pipelines/asr_bleu/asr_bleu.py @@ -32,7 +32,8 @@ async def run(self): self.config.split, self.config.model_name, self.config.eval_first_pass, - self.config.dataset_name, + self.config.dataset, + self.config.audio_format, self.config.datasets, self.launcher, ) diff --git a/stopes/pipelines/asr_bleu/compute_asr_bleu.py b/stopes/pipelines/asr_bleu/compute_asr_bleu.py index ec26901..9096b56 100644 --- a/stopes/pipelines/asr_bleu/compute_asr_bleu.py +++ b/stopes/pipelines/asr_bleu/compute_asr_bleu.py @@ -1,4 +1,5 @@ import logging +import torch import typing as tp from dataclasses import dataclass @@ -7,8 +8,7 @@ from stopes.core.launcher import Launcher from stopes.core.stopes_module import Requirements, StopesModule -from stopes.pipelines.asr_bleu.configs import DatasetsConfig - +from stopes.pipelines.asr_bleu.configs import Dataset @dataclass class ComputeASRBleuJob: @@ -31,6 +31,7 @@ class ComputeASRBleu(StopesModule): def __init__(self, config: ComputeASRBleuConfig): super().__init__(config=config, config_class=ComputeASRBleuConfig) self.asrbleu = ASRBleu(config.output_dir) + self.logger = logging.getLogger("stopes.asr_bleu") def array(self): return self.config.compute_asrbleu_jobs @@ -39,7 +40,7 @@ def requirements(self) -> Requirements: return Requirements( nodes=1, tasks_per_node=1, - gpus_per_node=0, + gpus_per_node=1, cpus_per_task=1, timeout_min=24 * 60, ) @@ -51,7 +52,6 @@ def run( ): """Runs compute_asr_bleu for each ComputeASRBleuJob""" assert iteration_value is not None, "iteration value is null" - self.logger = logging.getLogger("stopes.asr_bleu") self.logger.info(f"Running compute_asr_bleu on {iteration_value.lang_dir}") self.asrbleu.compute_asr_bleu( iteration_value.lang_dir, @@ -70,13 +70,13 @@ async def compute_asr_bleu( model_name: str, eval_first_pass: bool, dataset_name: str, - datasets_conf: DatasetsConfig, + audio_format: str, + datasets: tp.Dict[str, Dataset], launcher: Launcher, ) -> tp.List[tp.Tuple[tp.Dict[str, tp.List], str, str]]: """ Compute ASRBleu on specified datasets """ - datasets = datasets_conf.datasets compute_asrbleu_jobs = [ ComputeASRBleuJob( lang_dir=datasets[dataset].lang_dir, @@ -85,6 +85,7 @@ async def compute_asr_bleu( model_name=model_name, eval_first_pass=eval_first_pass, dataset=dataset_name, + audio_format=audio_format, ) for dataset in datasets ] diff --git a/stopes/pipelines/asr_bleu/conf/asr_bleu.yaml b/stopes/pipelines/asr_bleu/conf/asr_bleu.yaml index e9fe6c3..2bd1cad 100644 --- a/stopes/pipelines/asr_bleu/conf/asr_bleu.yaml +++ b/stopes/pipelines/asr_bleu/conf/asr_bleu.yaml @@ -4,7 +4,7 @@ defaults: output_dir: ??? split: "test" -model_name: "SeamlessM4T_medium" +model_name: "seamlessM4T_medium" eval_first_pass: True dataset: "fleurs" audio_format: "n_pred.wav" diff --git a/stopes/pipelines/asr_bleu/configs.py b/stopes/pipelines/asr_bleu/configs.py index 50663c9..a4f85d5 100644 --- a/stopes/pipelines/asr_bleu/configs.py +++ b/stopes/pipelines/asr_bleu/configs.py @@ -8,11 +8,6 @@ class Dataset: num_data_pairs: int -@dataclass -class DatasetsConfig: - datasets: tp.Dict[str, Dataset] - - @dataclass class AsrBleuConfig: output_dir: str @@ -22,4 +17,4 @@ class AsrBleuConfig: dataset: str audio_format: str launcher: tp.Dict[str, tp.Any] - datasets: DatasetsConfig + datasets: tp.Dict[str, Dataset] From 5eebccd37d3d405e204e2476568e7a1d50304c33 Mon Sep 17 00:00:00 2001 From: Calder Johnson Date: Mon, 11 Sep 2023 16:23:31 -0400 Subject: [PATCH 3/4] Ran linter --- stopes/pipelines/asr_bleu/compute_asr_bleu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stopes/pipelines/asr_bleu/compute_asr_bleu.py b/stopes/pipelines/asr_bleu/compute_asr_bleu.py index 9096b56..0e81568 100644 --- a/stopes/pipelines/asr_bleu/compute_asr_bleu.py +++ b/stopes/pipelines/asr_bleu/compute_asr_bleu.py @@ -1,8 +1,8 @@ import logging -import torch import typing as tp from dataclasses import dataclass +import torch from m4t_scripts.evaluate.asr_bleu import ASRBleu from omegaconf.omegaconf import MISSING @@ -10,6 +10,7 @@ from stopes.core.stopes_module import Requirements, StopesModule from stopes.pipelines.asr_bleu.configs import Dataset + @dataclass class ComputeASRBleuJob: lang_dir: str = MISSING From 90b44771cd468fdf417a64cb2a95647c44ba9f0f Mon Sep 17 00:00:00 2001 From: Calder Johnson Date: Mon, 11 Sep 2023 16:44:20 -0400 Subject: [PATCH 4/4] Removed unused imports --- stopes/pipelines/asr_bleu/compute_asr_bleu.py | 1 - 1 file changed, 1 deletion(-) diff --git a/stopes/pipelines/asr_bleu/compute_asr_bleu.py b/stopes/pipelines/asr_bleu/compute_asr_bleu.py index 0e81568..10a1a20 100644 --- a/stopes/pipelines/asr_bleu/compute_asr_bleu.py +++ b/stopes/pipelines/asr_bleu/compute_asr_bleu.py @@ -2,7 +2,6 @@ import typing as tp from dataclasses import dataclass -import torch from m4t_scripts.evaluate.asr_bleu import ASRBleu from omegaconf.omegaconf import MISSING