Skip to content

Commit

Permalink
added speedup if examples for a single batch are coming from a single…
Browse files Browse the repository at this point in the history
… language in RNNT
  • Loading branch information
tahirjmakhdoomi committed Jan 22, 2025
1 parent d498ad1 commit a1c90f4
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 9 deletions.
6 changes: 4 additions & 2 deletions examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,11 @@ class TranscriptionConfig:
allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU)
amp: bool = False
amp_dtype: str = "float16" # can be set to "float16" or "bfloat16" when using amp
matmul_precision: str = "highest" # Literal["highest", "high", "medium"]
matmul_precision: str = "medium" # Literal["highest", "high", "medium"]
audio_type: str = "wav"

# Recompute model transcription, even if the output folder exists with scores.
overwrite_transcripts: bool = True
overwrite_transcripts: bool = False

# Decoding strategy for CTC models
ctc_decoding: CTCDecodingConfig = CTCDecodingConfig()
Expand Down Expand Up @@ -407,6 +407,8 @@ def autocast(dtype=None):
override_cfg.augmentor = augmentor
override_cfg.text_field = cfg.gt_text_attr_name
override_cfg.lang_field = cfg.gt_lang_attr_name
override_cfg.language_id = cfg.langid

transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,)

if cfg.dataset_manifest is not None:
Expand Down
1 change: 1 addition & 0 deletions examples/asr/transcribe_speech_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def match_train_config(predict_ds, train_ds):

@hydra_runner(config_name="TranscriptionConfig", schema=ParallelTranscriptionConfig)
def main(cfg: ParallelTranscriptionConfig):
torch.set_float32_matmul_precision('medium')
if cfg.model.endswith(".nemo"):
logging.info("Attempting to initialize from .nemo file")
model = ASRModel.restore_from(restore_path=cfg.model, map_location="cpu")
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,8 +648,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
)

sample_id = sample_id.cpu().detach().numpy()
# breakpoint()
sample_id = sample_ids.cpu().detach().numpy()
return list(zip(sample_id, best_hyp_text))

def validation_pass(self, batch, batch_idx, dataloader_idx):
Expand Down
14 changes: 9 additions & 5 deletions nemo/collections/asr/modules/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1599,14 +1599,18 @@ def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor, language_ids=
inp = self.forward_enabled_adapters(inp)

if language_ids is not None: #CTEMO

# Do partial forward of joint net (skipping the final linear)
for module in self.joint_net[:-1]:
inp = module(inp) # [B, T, U, H]
res_single = []
for single_inp, lang in zip(inp, language_ids):
res_single.append(self.joint_net[-1][lang](single_inp))
res = torch.stack(res_single)

# If every item in the batch has the same language, run the whole batch through that language's final layer in a single call
if len(set(language_ids)) == 1:
res = self.joint_net[-1][language_ids[0]](inp)
else:
res_single = []
for single_inp, lang in zip(inp, language_ids):
res_single.append(self.joint_net[-1][lang](single_inp))
res = torch.stack(res_single)
else:
res = self.joint_net(inp) # [B, T, U, V + 1]

Expand Down

0 comments on commit a1c90f4

Please sign in to comment.