Skip to content

Commit

Permalink
added decoding fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ASR committed Feb 21, 2024
1 parent 452f4f4 commit 29b9f7b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
12 changes: 10 additions & 2 deletions nemo/collections/asr/models/ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,16 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):
decoding_cls = OmegaConf.structured(CTCBPEDecodingConfig)
decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

self.decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer,)

if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
if decoding_cfg.strategy == 'pyctcdecode':
# create separate decoders for each language
# self.decoding = [CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()),lang=l) for l in self.tokenizer.tokenizers_dict.keys()]
self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()),lang='any')
else:
self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
else:
self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer)

self._wer = WERBPE(
decoding=self.decoding,
Expand Down
9 changes: 7 additions & 2 deletions nemo/collections/asr/models/ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def transcribe(
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: DictConfig = None,
verbose: bool = True,
language_id: str = None
) -> List[str]:
"""
If modify this function, please remember update transcribe_partial_audio() in
Expand Down Expand Up @@ -197,8 +198,12 @@ def transcribe(

temporary_datalayer = self._setup_transcribe_dataloader(config)
for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose):
if "multisoftmax" not in self.cfg.decoder:
language_ids = None
else:
language_ids = [language_id] * len(test_batch[0])
logits, logits_len, greedy_predictions = self.forward(
input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device),language_ids=language_ids
)

if logprobs:
Expand All @@ -208,7 +213,7 @@ def transcribe(
hypotheses.append(lg.cpu().numpy())
else:
current_hypotheses, all_hyp = self.decoding.ctc_decoder_predictions_tensor(
logits, decoder_lengths=logits_len, return_hypotheses=return_hypotheses,
logits, decoder_lengths=logits_len, return_hypotheses=return_hypotheses, lang_ids=language_ids
)
logits = logits.cpu()

Expand Down
5 changes: 4 additions & 1 deletion nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
with open_dict(self.cfg.aux_ctc):
self.cfg.aux_ctc.decoding = ctc_decoding_cfg
if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in cfg.decoder:
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
if ctc_decoding_cfg.strategy == 'pyctcdecode':
self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()),lang='any')
else:
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
else:
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer)

Expand Down

0 comments on commit 29b9f7b

Please sign in to comment.