Skip to content

Commit

Permalink
added output path for parallel inference
Browse files Browse the repository at this point in the history
  • Loading branch information
tahirjmakhdoomi committed Jan 22, 2025
1 parent a1c90f4 commit 01f1d43
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion examples/asr/transcribe_speech_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def main(cfg: ParallelTranscriptionConfig):
pred_text_list = []
text_list = []
if is_global_rank_zero():
- output_file = os.path.join(cfg.output_path, f"predictions_all.json")
+ output_file = os.path.join(cfg.output_path, cfg.output_filename)
logging.info(f"Prediction files are being aggregated in {output_file}.")
with open(output_file, 'w') as outf:
for rank in range(trainer.world_size):
Expand All @@ -198,6 +198,7 @@ def main(cfg: ParallelTranscriptionConfig):
text_list.append(item["text"])
outf.write(json.dumps(item) + "\n")
samples_num += 1
+ os.remove(input_file)
wer_cer = word_error_rate(hypotheses=pred_text_list, references=text_list, use_cer=cfg.use_cer)
logging.info(
f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}."
Expand Down

0 comments on commit 01f1d43

Please sign in to comment.