diff --git a/examples/asr/transcribe_speech_parallel.py b/examples/asr/transcribe_speech_parallel.py index 6c1cf111a..6382893cf 100644 --- a/examples/asr/transcribe_speech_parallel.py +++ b/examples/asr/transcribe_speech_parallel.py @@ -185,7 +185,7 @@ def main(cfg: ParallelTranscriptionConfig): pred_text_list = [] text_list = [] if is_global_rank_zero(): - output_file = os.path.join(cfg.output_path, f"predictions_all.json") + output_file = os.path.join(cfg.output_path, cfg.output_filename) logging.info(f"Prediction files are being aggregated in {output_file}.") with open(output_file, 'w') as outf: for rank in range(trainer.world_size): @@ -198,6 +198,7 @@ def main(cfg: ParallelTranscriptionConfig): text_list.append(item["text"]) outf.write(json.dumps(item) + "\n") samples_num += 1 + os.remove(input_file) wer_cer = word_error_rate(hypotheses=pred_text_list, references=text_list, use_cer=cfg.use_cer) logging.info( f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}."