Skip to content

Commit

Permalink
Added WER calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
tahirjmakhdoomi committed Jan 22, 2025
1 parent 01f1d43 commit 0a1560e
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion examples/asr/transcribe_speech_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel
from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.core.config import TrainerConfig, hydra_runner
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero
Expand Down Expand Up @@ -187,6 +188,14 @@ def main(cfg: ParallelTranscriptionConfig):
if is_global_rank_zero():
output_file = os.path.join(cfg.output_path, cfg.output_filename)
logging.info(f"Prediction files are being aggregated in {output_file}.")

fname2obj = {}
# Read the input manifest and index its entries by 'audio_filepath' so that their fields can be merged into each prediction record below.
with open(cfg.predict_ds.manifest_filepath) as reader:
lines = [json.loads(s) for s in reader.read().strip().splitlines()]
for l in lines:
fname2obj[l['audio_filepath']] = l

with open(output_file, 'w') as outf:
for rank in range(trainer.world_size):
input_file = os.path.join(cfg.output_path, f"predictions_{rank}.json")
Expand All @@ -196,7 +205,7 @@ def main(cfg: ParallelTranscriptionConfig):
item = json.loads(line)
pred_text_list.append(item["pred_text"])
text_list.append(item["text"])
outf.write(json.dumps(item) + "\n")
outf.write(json.dumps(item | fname2obj[item['audio_filepath']]) + "\n")
samples_num += 1
os.remove(input_file)
wer_cer = word_error_rate(hypotheses=pred_text_list, references=text_list, use_cer=cfg.use_cer)
Expand All @@ -205,6 +214,21 @@ def main(cfg: ParallelTranscriptionConfig):
)
logging.info("{} for all predictions is {:.4f}.".format("CER" if cfg.use_cer else "WER", wer_cer))

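# Compute per-sample error rates for the aggregated predictions with cal_write_wer, which returns the path of the manifest it writes.
# NOTE(review): use_cer is hard-coded to False here, while the aggregate metric above honours cfg.use_cer; confirm this is intended.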

output_manifest_w_wer, total_res, _ = cal_write_wer(
pred_manifest=output_file,
gt_text_attr_name='text',
pred_text_attr_name='pred_text',
clean_groundtruth_text=False,
langid=None,
use_cer=False,
output_filename=None,
)
if output_manifest_w_wer:
logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!")
logging.info(f"{total_res}")


# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
main()

0 comments on commit 0a1560e

Please sign in to comment.