From a1c90f425fec0b1754ba37dc8d44917273498739 Mon Sep 17 00:00:00 2001 From: Tahir Javed Date: Wed, 22 Jan 2025 16:20:26 +0530 Subject: [PATCH] added speedup if examples for a single batch are coming from a single language in RNNT --- examples/asr/transcribe_speech.py | 6 ++++-- examples/asr/transcribe_speech_parallel.py | 1 + .../asr/models/hybrid_rnnt_ctc_models.py | 4 ++-- nemo/collections/asr/modules/rnnt.py | 14 +++++++++----- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index c8372c422..f4bb2d829 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -151,11 +151,11 @@ class TranscriptionConfig: allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU) amp: bool = False amp_dtype: str = "float16" # can be set to "float16" or "bfloat16" when using amp - matmul_precision: str = "highest" # Literal["highest", "high", "medium"] + matmul_precision: str = "medium" # Literal["highest", "high", "medium"] audio_type: str = "wav" # Recompute model transcription, even if the output folder exists with scores. - overwrite_transcripts: bool = True + overwrite_transcripts: bool = False # Decoding strategy for CTC models ctc_decoding: CTCDecodingConfig = CTCDecodingConfig() @@ -407,6 +407,8 @@ def autocast(dtype=None): override_cfg.augmentor = augmentor override_cfg.text_field = cfg.gt_text_attr_name override_cfg.lang_field = cfg.gt_lang_attr_name + override_cfg.language_id = cfg.langid + transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,) if cfg.dataset_manifest is not None: diff --git a/examples/asr/transcribe_speech_parallel.py b/examples/asr/transcribe_speech_parallel.py index c0af8f971..6c1cf111a 100644 --- a/examples/asr/transcribe_speech_parallel.py +++ b/examples/asr/transcribe_speech_parallel.py @@ -140,6 +140,7 @@ def match_train_config(predict_ds, train_ds): @hydra_runner(config_name="TranscriptionConfig", schema=ParallelTranscriptionConfig) def main(cfg: ParallelTranscriptionConfig): + torch.set_float32_matmul_precision('medium') if cfg.model.endswith(".nemo"): logging.info("Attempting to initialize from .nemo file") model = ASRModel.restore_from(restore_path=cfg.model, map_location="cpu") diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index 72c821f99..e08819721 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -648,8 +648,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor( encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False ) - - sample_id = sample_id.cpu().detach().numpy() + # breakpoint() + sample_id = sample_ids.cpu().detach().numpy() return list(zip(sample_id, best_hyp_text)) def validation_pass(self, batch, batch_idx, dataloader_idx): diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index a5e69751c..d7f9048fb 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -1599,14 +1599,18 @@ def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor, language_ids= inp = self.forward_enabled_adapters(inp) if language_ids is not None: #CTEMO - # Do partial forward of joint net (skipping the final linear) for module in self.joint_net[:-1]: inp = module(inp) # [B, T, U, H] - res_single = [] - for single_inp, lang in zip(inp, language_ids): - res_single.append(self.joint_net[-1][lang](single_inp)) - res = torch.stack(res_single) + + # check if all the items in the batch have the same langauge, pass them through + if len(set(language_ids)) == 1: + res = self.joint_net[-1][language_ids[0]](inp) + else: + res_single = [] + for single_inp, lang in zip(inp, language_ids): + res_single.append(self.joint_net[-1][lang](single_inp)) + res = torch.stack(res_single) else: res = self.joint_net(inp) # [B, T, U, V + 1]