Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate asr bleu #58

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ classifiers=[
"faiss-gpu",
]
speech = [
"fairseq2",
"hanziconv",
"inflect",
"seamless_communication",
"tnkeeh",
"torchaudio",
"num2words",
Expand Down
Empty file.
53 changes: 53 additions & 0 deletions stopes/pipelines/asr_bleu/asr_bleu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import asyncio
import logging
from pathlib import Path

import hydra
from omegaconf import OmegaConf

from stopes.core import utils
from stopes.pipelines.asr_bleu.compute_asr_bleu import compute_asr_bleu
from stopes.pipelines.asr_bleu.configs import AsrBleuConfig

logger = logging.getLogger("asr_bleu")


class AsrBleu:
def __init__(self, config: AsrBleuConfig):
self.config = config
self.ensure_all_dirs()
self.launcher = hydra.utils.instantiate(self.config.launcher)
self.config.launcher.cache.caching_dir = Path(self.output_dir) / "cache"
OmegaConf.save(
config=config,
f=str(self.output_dir / "asr_bleu.yaml"),
)

OmegaConf.set_readonly(self.config, True)

async def run(self):
logger.info("Computing ASRBleu on selected datasets...")
await compute_asr_bleu(
self.config.output_dir,
self.config.split,
self.config.model_name,
self.config.eval_first_pass,
self.config.dataset,
self.config.audio_format,
self.config.datasets,
self.launcher,
)

def ensure_all_dirs(self) -> None:
self.output_dir = Path(self.config.output_dir).resolve()
utils.ensure_dir(self.output_dir)


@hydra.main(config_path="conf", config_name="asr_bleu")
def main(config: AsrBleuConfig) -> None:
pipeline = AsrBleu(config)
asyncio.run(pipeline.run())


if __name__ == "__main__":
main()
97 changes: 97 additions & 0 deletions stopes/pipelines/asr_bleu/compute_asr_bleu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import logging
import typing as tp
from dataclasses import dataclass

from m4t_scripts.evaluate.asr_bleu import ASRBleu
from omegaconf.omegaconf import MISSING

from stopes.core.launcher import Launcher
from stopes.core.stopes_module import Requirements, StopesModule
from stopes.pipelines.asr_bleu.configs import Dataset


@dataclass
class ComputeASRBleuJob:
lang_dir: str = MISSING
split: str = MISSING
num_data_pairs: int = MISSING
model_name: str = MISSING
eval_first_pass: bool = MISSING
dataset: str = MISSING
audio_format: str = MISSING


@dataclass
class ComputeASRBleuConfig:
compute_asrbleu_jobs: tp.List[ComputeASRBleuJob] = MISSING
output_dir: str = MISSING


class ComputeASRBleu(StopesModule):
def __init__(self, config: ComputeASRBleuConfig):
super().__init__(config=config, config_class=ComputeASRBleuConfig)
self.asrbleu = ASRBleu(config.output_dir)
self.logger = logging.getLogger("stopes.asr_bleu")

def array(self):
return self.config.compute_asrbleu_jobs

def requirements(self) -> Requirements:
return Requirements(
nodes=1,
tasks_per_node=1,
gpus_per_node=1,
cpus_per_task=1,
timeout_min=24 * 60,
)

def run(
self,
iteration_value: tp.Optional[tp.Any] = None,
iteration_index: int = 0,
):
"""Runs compute_asr_bleu for each ComputeASRBleuJob"""
assert iteration_value is not None, "iteration value is null"
self.logger.info(f"Running compute_asr_bleu on {iteration_value.lang_dir}")
self.asrbleu.compute_asr_bleu(
iteration_value.lang_dir,
iteration_value.split,
iteration_value.num_data_pairs,
iteration_value.model_name,
iteration_value.eval_first_pass,
iteration_value.dataset,
iteration_value.audio_format,
)


async def compute_asr_bleu(
output_dir: str,
split: str,
model_name: str,
eval_first_pass: bool,
dataset_name: str,
audio_format: str,
datasets: tp.Dict[str, Dataset],
launcher: Launcher,
) -> tp.List[tp.Tuple[tp.Dict[str, tp.List], str, str]]:
"""
Compute ASRBleu on specified datasets
"""
compute_asrbleu_jobs = [
ComputeASRBleuJob(
lang_dir=datasets[dataset].lang_dir,
split=split,
num_data_pairs=datasets[dataset].num_data_pairs,
model_name=model_name,
eval_first_pass=eval_first_pass,
dataset=dataset_name,
audio_format=audio_format,
)
for dataset in datasets
]
compute_asrbleu_module = ComputeASRBleu(
ComputeASRBleuConfig(
compute_asrbleu_jobs=compute_asrbleu_jobs, output_dir=output_dir
)
)
await launcher.schedule(compute_asrbleu_module)
17 changes: 17 additions & 0 deletions stopes/pipelines/asr_bleu/conf/asr_bleu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
defaults:
- launcher: local
- _self_

output_dir: ???
split: "test"
model_name: "seamlessM4T_medium"
eval_first_pass: True
dataset: "fleurs"
audio_format: "n_pred.wav"

launcher:
partition: ??? # set as null if running locally
cache:
caching_dir: ${output_dir}/cache

datasets: ???
2 changes: 2 additions & 0 deletions stopes/pipelines/asr_bleu/conf/launcher/cache/file_cache.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_target_: stopes.core.FileCache
caching_dir: /tmp/stopes_cache
8 changes: 8 additions & 0 deletions stopes/pipelines/asr_bleu/conf/launcher/local.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
defaults:
- cache: file_cache

_target_: stopes.core.Launcher
log_folder: executor_logs
cluster: local
partition: null
max_jobarray_jobs: 1000
8 changes: 8 additions & 0 deletions stopes/pipelines/asr_bleu/conf/launcher/submitit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
defaults:
- cache: file_cache

_target_: stopes.core.Launcher
log_folder: executor_logs
cluster: slurm
partition: null
max_jobarray_jobs: 1000
20 changes: 20 additions & 0 deletions stopes/pipelines/asr_bleu/configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import typing as tp
from dataclasses import dataclass


@dataclass
class Dataset:
lang_dir: str
num_data_pairs: int


@dataclass
class AsrBleuConfig:
output_dir: str
split: str
model_name: str
eval_first_pass: bool
dataset: str
audio_format: str
launcher: tp.Dict[str, tp.Any]
datasets: tp.Dict[str, Dataset]