diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py
index 5f380619d..d56c12704 100644
--- a/nemo/collections/asr/models/ctc_models.py
+++ b/nemo/collections/asr/models/ctc_models.py
@@ -124,6 +124,8 @@ def transcribe(
         channel_selector: Optional[ChannelSelectorType] = None,
         augmentor: DictConfig = None,
         verbose: bool = True,
+        logprobs: bool = False,
+        language_id: str = None,
         override_config: Optional[TranscribeConfig] = None,
     ) -> TranscriptionReturnType:
         """
@@ -159,6 +161,7 @@ def transcribe(
             channel_selector=channel_selector,
             augmentor=augmentor,
             verbose=verbose,
+            language_id=language_id,
             override_config=override_config,
         )

diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
index 4da9bf826..99912258b 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -434,7 +434,7 @@ def change_vocabulary(
         logging.info(f"Changed tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.")

-    def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None):
+    def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None, lang_id: str = None):
         """
         Changes decoding strategy used during RNNT decoding process.
         Args:
@@ -458,7 +458,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
             if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder: #CTEMO
                 self.decoding = RNNTBPEDecoding(
-                    decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys())
+                    decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()),
+                    # lang_id=lang_id
                 )
             else:
                 self.decoding = RNNTBPEDecoding(
@@ -488,7 +489,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
             self.cur_decoder = "rnnt"
             logging.info(f"Changed decoding strategy of the RNNT decoder to \n{OmegaConf.to_yaml(self.cfg.decoding)}")
-
+
         elif decoder_type == 'ctc':
             if not hasattr(self, 'ctc_decoding'):
                 raise ValueError("The model does not have the ctc_decoding module and does not support ctc decoding.")
@@ -503,9 +504,9 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
             decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

             if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder: #CTEMO
-                self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
+                self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()), lang_id=lang_id)
             else:
-                self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer)
+                self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer, lang_id=lang_id)

             self.ctc_wer = WER(
                 decoding=self.ctc_decoding,

diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
index ecdf80d24..4a3571eb9 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -172,7 +172,7 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
         if "multisoftmax" not in self.cfg.decoder:
             language_ids = None
         else:
-            language_ids = [language_id] * len(batch[0])
+            language_ids = [trcfg.language_id] * len(batch[0])
         logits = self.ctc_decoder(encoder_output=encoded, language_ids=language_ids)
         output = dict(logits=logits, encoded_len=encoded_len, language_ids=language_ids)

diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py
index e19518533..19c468ea4 100644
--- a/nemo/collections/asr/models/rnnt_models.py
+++ b/nemo/collections/asr/models/rnnt_models.py
@@ -283,6 +283,7 @@ def transcribe(
             verbose=verbose,
             override_config=override_config,
             logprobs=logprobs,
+            language_id=language_id,
             # Additional arguments
             partial_hypothesis=partial_hypothesis,
         )
@@ -880,12 +881,18 @@ def _transcribe_output_processing(
     ) -> Tuple[List['Hypothesis'], List['Hypothesis']]:
         encoded = outputs.pop('encoded')
         encoded_len = outputs.pop('encoded_len')
+
+        if "multisoftmax" not in self.cfg.decoder:
+            language_ids = None
+        else:
+            language_ids = [trcfg.language_id] * len(encoded)

         best_hyp, all_hyp = self.decoding.rnnt_decoder_predictions_tensor(
             encoded,
             encoded_len,
             return_hypotheses=trcfg.return_hypotheses,
             partial_hypotheses=trcfg.partial_hypothesis,
+            lang_ids=language_ids,
         )

         # cleanup memory

diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py
index 0ad0a3482..a5e69751c 100644
--- a/nemo/collections/asr/modules/rnnt.py
+++ b/nemo/collections/asr/modules/rnnt.py
@@ -1597,6 +1597,7 @@ def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor, language_ids=
         # Forward adapter modules on joint hidden
         if self.is_adapter_available():
             inp = self.forward_enabled_adapters(inp)
+
         if language_ids is not None: #CTEMO
             # Do partial forward of joint net (skipping the final linear)

diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py
index 449374d0a..54b953015 100644
--- a/nemo/collections/asr/parts/mixins/transcription.py
+++ b/nemo/collections/asr/parts/mixins/transcription.py
@@ -61,6 +61,7 @@ class TranscribeConfig:
     augmentor: Optional[DictConfig] = None
     verbose: bool = True
     logprobs: bool = False
+    language_id: str = None

     # Utility
     partial_hypothesis: Optional[List[Any]] = None
@@ -196,6 +197,7 @@ def transcribe(
         verbose: bool = True,
         override_config: Optional[TranscribeConfig] = None,
         logprobs: bool = False,
+        language_id: str = None,
         **config_kwargs,
     ) -> GenericTranscriptionType:
         """
@@ -245,6 +247,7 @@ def transcribe(
                 augmentor=augmentor,
                 verbose=verbose,
                 logprobs=logprobs,
+                language_id=language_id,
                 **config_kwargs,
             )
         else:

diff --git a/nemo/collections/asr/parts/submodules/ctc_decoding.py b/nemo/collections/asr/parts/submodules/ctc_decoding.py
index 178c62976..f24cdc40a 100644
--- a/nemo/collections/asr/parts/submodules/ctc_decoding.py
+++ b/nemo/collections/asr/parts/submodules/ctc_decoding.py
@@ -1211,7 +1211,7 @@ class CTCBPEDecoding(AbstractCTCDecoding):
         tokenizer: NeMo tokenizer object, which inherits from TokenizerSpec.
""" - def __init__(self, decoding_cfg, tokenizer: TokenizerSpec, blank_id = None): #CTEMO + def __init__(self, decoding_cfg, tokenizer: TokenizerSpec, blank_id = None, lang_id: str = None): #CTEMO if blank_id is None: blank_id = tokenizer.tokenizer.vocab_size self.tokenizer = tokenizer @@ -1223,7 +1223,12 @@ def __init__(self, decoding_cfg, tokenizer: TokenizerSpec, blank_id = None): #CT if hasattr(self.tokenizer.tokenizer, 'get_vocab'): vocab_dict = self.tokenizer.tokenizer.get_vocab() if isinstance(self.tokenizer.tokenizer, DummyTokenizer): # AggregateTokenizer.DummyTokenizer - vocab = vocab_dict + if lang_id is not None: + tokenizer = self.tokenizer.tokenizers_dict[lang_id] + vocab_dict = tokenizer.tokenizer.get_vocab() + vocab = list(vocab_dict.keys()) + else: + vocab = vocab_dict else: vocab = list(vocab_dict.keys()) self.decoding.set_vocabulary(vocab) diff --git a/nemo/collections/common/tokenizers/multilingual_tokenizer.py b/nemo/collections/common/tokenizers/multilingual_tokenizer.py index 6357552db..1b4e66ed3 100644 --- a/nemo/collections/common/tokenizers/multilingual_tokenizer.py +++ b/nemo/collections/common/tokenizers/multilingual_tokenizer.py @@ -17,23 +17,12 @@ import numpy as np from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.collections.common.tokenizers.aggregate_tokenizer import DummyTokenizer from nemo.utils import logging __all__ = ['MultilingualTokenizer'] -class DummyTokenizer: - def __init__(self, vocab): - self.vocab = vocab - self.vocab_size = len(vocab) - - # minimum compatibility - # since all the monolingual tokenizers have a vocab - # additional methods could be added here - def get_vocab(self): - return self.vocab - - class MultilingualTokenizer(TokenizerSpec): ''' MultilingualTokenizer, allowing one to combine multiple regular monolongual tokenizers into one tokenizer.