# Patch summary for examples/asr/transcribe_speech_parallel.py
#
# What the hunks below change:
#   * Adds an import of cal_write_wer from
#     nemo.collections.asr.parts.utils.eval_utils.
#   * In main(), inside the is_global_rank_zero() branch and before the
#     per-rank prediction files are aggregated, opens
#     cfg.predict_ds.manifest_filepath, parses each line as JSON, and builds
#     fname2obj: a dict keyed by each entry's 'audio_filepath'.
#   * Each aggregated prediction is now written as
#     json.dumps(item | fname2obj[item['audio_filepath']]), i.e. the
#     prediction record merged with its manifest entry.
#   * After the overall WER/CER has been logged, calls cal_write_wer() on the
#     aggregated output file, then logs the returned manifest path (only when
#     it is truthy) and total_res.
#
# NOTE(review): the line breaks of this diff appear to have been collapsed
#   into spaces in this copy; it presumably will not apply as-is -- confirm
#   against the original patch.
# NOTE(review): in `item | fname2obj[...]` the right-hand operand wins on
#   duplicate keys, so manifest fields override same-named fields of the
#   prediction record -- confirm this precedence is intended. The dict `|`
#   operator also requires Python 3.9+.
# NOTE(review): fname2obj[...] is a plain subscript, so a prediction whose
#   'audio_filepath' is missing from the manifest raises KeyError.
# NOTE(review): cal_write_wer() is called with use_cer=False, while the
#   aggregate metric just above it honours cfg.use_cer -- confirm whether the
#   per-sample report should follow cfg.use_cer as well.
diff --git a/examples/asr/transcribe_speech_parallel.py b/examples/asr/transcribe_speech_parallel.py index 6382893cf..0ae197f40 100644 --- a/examples/asr/transcribe_speech_parallel.py +++ b/examples/asr/transcribe_speech_parallel.py @@ -84,6 +84,7 @@ from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig +from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer from nemo.core.config import TrainerConfig, hydra_runner from nemo.utils import logging from nemo.utils.get_rank import is_global_rank_zero @@ -187,6 +188,14 @@ def main(cfg: ParallelTranscriptionConfig): if is_global_rank_zero(): output_file = os.path.join(cfg.output_path, cfg.output_filename) logging.info(f"Prediction files are being aggregated in {output_file}.") + + fname2obj = {} + # read the manifest and add the other details + with open(cfg.predict_ds.manifest_filepath) as reader: + lines = [json.loads(s) for s in reader.read().strip().splitlines()] + for l in lines: + fname2obj[l['audio_filepath']] = l + with open(output_file, 'w') as outf: for rank in range(trainer.world_size): input_file = os.path.join(cfg.output_path, f"predictions_{rank}.json") @@ -196,7 +205,7 @@ def main(cfg: ParallelTranscriptionConfig): item = json.loads(line) pred_text_list.append(item["pred_text"]) text_list.append(item["text"]) - outf.write(json.dumps(item) + "\n") + outf.write(json.dumps(item | fname2obj[item['audio_filepath']]) + "\n") samples_num += 1 os.remove(input_file) wer_cer = word_error_rate(hypotheses=pred_text_list, references=text_list, use_cer=cfg.use_cer) @@ -205,6 +214,21 @@ def main(cfg: ParallelTranscriptionConfig): ) logging.info("{} for all predictions is {:.4f}.".format("CER" if cfg.use_cer else "WER", wer_cer)) + # lets add some more details to the file + + output_manifest_w_wer, total_res, _ = 
cal_write_wer( + pred_manifest=output_file, + gt_text_attr_name='text', + pred_text_attr_name='pred_text', + clean_groundtruth_text=False, + langid=None, + use_cer=False, + output_filename=None, + ) + if output_manifest_w_wer: + logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!") + logging.info(f"{total_res}") + if __name__ == '__main__': main()