diff --git a/nemo/collections/asr/data/audio_to_text.py b/nemo/collections/asr/data/audio_to_text.py index 58cd3630e..ffb56478c 100644 --- a/nemo/collections/asr/data/audio_to_text.py +++ b/nemo/collections/asr/data/audio_to_text.py @@ -59,22 +59,27 @@ def _speech_collate_fn(batch, pad_id): assumes the signals are 1d torch tensors (i.e. mono audio). """ packed_batch = list(zip(*batch)) - if len(packed_batch) == 5: + if len(packed_batch) == 6: # has language ids + _, audio_lengths, _, tokens_lengths, sample_ids, language_ids = packed_batch + elif len(packed_batch) == 5: # has sample ids + language_ids = None _, audio_lengths, _, tokens_lengths, sample_ids = packed_batch elif len(packed_batch) == 4: - sample_ids = None + sample_ids, language_ids = None, None _, audio_lengths, _, tokens_lengths = packed_batch else: - raise ValueError("Expects 4 or 5 tensors in the batch!") + raise ValueError("Expects 4 or 5 or 6 tensors in the batch!") max_audio_len = 0 has_audio = audio_lengths[0] is not None if has_audio: max_audio_len = max(audio_lengths).item() max_tokens_len = max(tokens_lengths).item() - + audio_signal, tokens = [], [] for b in batch: - if len(b) == 5: + if len(b) == 6: + sig, sig_len, tokens_i, tokens_i_len, _, _ = b + elif len(b) == 5: sig, sig_len, tokens_i, tokens_i_len, _ = b else: sig, sig_len, tokens_i, tokens_i_len = b @@ -97,12 +102,14 @@ def _speech_collate_fn(batch, pad_id): audio_signal, audio_lengths = None, None tokens = torch.stack(tokens) tokens_lengths = torch.stack(tokens_lengths) - if sample_ids is None: - return audio_signal, audio_lengths, tokens, tokens_lengths - else: + if language_ids is not None: + sample_ids = torch.tensor(sample_ids, dtype=torch.int32) + return audio_signal, audio_lengths, tokens, tokens_lengths, sample_ids, list(language_ids) + elif sample_ids is not None: sample_ids = torch.tensor(sample_ids, dtype=torch.int32) return audio_signal, audio_lengths, tokens, tokens_lengths, sample_ids - + else: + return audio_signal, audio_lengths, tokens, tokens_lengths class ASRManifestProcessor: """ @@ -424,6 +431,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: 'transcripts': NeuralType(('B', 'T'), LabelsType()), 'transcript_length': NeuralType(tuple('B'), LengthsType()), 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), + 'language_id': [NeuralType(('B'), StringType(), optional=True)], } def __init__( @@ -441,6 +449,7 @@ def __init__( eos_id: Optional[int] = None, pad_id: int = 0, return_sample_id: bool = False, + return_language_id: bool = False, channel_selector: Optional[ChannelSelectorType] = None, ): if type(manifest_filepath) == str: @@ -462,6 +471,7 @@ def __init__( self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor) self.trim = trim self.return_sample_id = return_sample_id + self.return_language_id = return_language_id self.channel_selector = channel_selector def get_manifest_sample(self, sample_id): @@ -486,8 +496,10 @@ def __getitem__(self, index): t, tl = self.manifest_processor.process_text_by_sample(sample=sample) - if self.return_sample_id: - output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), index + if self.return_language_id: + output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), index, sample.lang + elif self.return_sample_id: + output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), index else: output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long() @@ -530,6 +542,7 @@ class AudioToCharDataset(_AudioTextDataset): 
bos_id: Id of beginning of sequence symbol to append if not None eos_id: Id of end of sequence symbol to append if not None return_sample_id (bool): whether to return the sample_id as a part of each sample + return_language_id (bool): whether to return the language_id as a part of each sample channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. """ @@ -543,6 +556,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: 'transcripts': NeuralType(('B', 'T'), LabelsType()), 'transcript_length': NeuralType(tuple('B'), LengthsType()), 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), + 'language_id': [NeuralType(('B'), StringType(), optional=True)], } def __init__( @@ -564,6 +578,7 @@ def __init__( pad_id: int = 0, parser: Union[str, Callable] = 'en', return_sample_id: bool = False, + return_language_id: bool = False, channel_selector: Optional[ChannelSelectorType] = None, ): self.labels = labels @@ -586,6 +601,7 @@ def __init__( eos_id=eos_id, pad_id=pad_id, return_sample_id=return_sample_id, + return_language_id=return_language_id, channel_selector=channel_selector, ) @@ -624,6 +640,7 @@ class AudioToBPEDataset(_AudioTextDataset): use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS] tokens to beginning and ending of speech respectively. return_sample_id (bool): whether to return the sample_id as a part of each sample + return_language_id (bool): whether to return the language_id as a part of each sample channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. 
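As a reading aid for the dataset and collate changes above: once `return_language_id=True` is set, each dataset item becomes a 6-tuple whose last element is the language code taken from the manifest entry, and `_speech_collate_fn` passes those codes through unchanged as a plain Python list. The toy batch below is illustrative only (random signals, made-up token ids and language codes); it assumes the patched `nemo.collections.asr.data.audio_to_text` module is importable.

```python
# Illustrative only: exercises the patched _speech_collate_fn with a toy
# 6-tuple batch (signal, signal_len, tokens, tokens_len, sample_id, lang).
import torch
from nemo.collections.asr.data.audio_to_text import _speech_collate_fn

batch = [
    (torch.randn(16000), torch.tensor(16000), torch.tensor([5, 6, 7]).long(), torch.tensor(3).long(), 0, "en"),
    (torch.randn(8000), torch.tensor(8000), torch.tensor([9, 2]).long(), torch.tensor(2).long(), 1, "hi"),
]

signals, signal_lens, tokens, token_lens, sample_ids, language_ids = _speech_collate_fn(batch, pad_id=0)
# signals: [2, 16000] zero-padded audio, tokens: [2, 3] padded with pad_id,
# sample_ids: int32 tensor([0, 1]), language_ids: ['en', 'hi'] (a plain list of str)
```

The language ids stay a list rather than a tensor because they are strings and are consumed per sample further downstream, for example by the language-aware decoding paths added later in this diff.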
""" @@ -637,6 +654,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: 'transcripts': NeuralType(('B', 'T'), LabelsType()), 'transcript_length': NeuralType(tuple('B'), LengthsType()), 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), + 'language_id': [NeuralType(('B'), StringType(), optional=True)], } def __init__( @@ -652,6 +670,7 @@ def __init__( trim: bool = False, use_start_end_token: bool = True, return_sample_id: bool = False, + return_language_id: bool = False, channel_selector: Optional[ChannelSelectorType] = None, ): if use_start_end_token and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0: @@ -671,7 +690,7 @@ def __init__( class TokenizerWrapper: def __init__(self, tokenizer): - if isinstance(tokenizer, tokenizers.aggregate_tokenizer.AggregateTokenizer): + if isinstance(tokenizer, tokenizers.aggregate_tokenizer.AggregateTokenizer) or isinstance(tokenizer, tokenizers.multilingual_tokenizer.MultilingualTokenizer): self.is_aggregate = True else: self.is_aggregate = False @@ -701,6 +720,7 @@ def __call__(self, *args): pad_id=pad_id, trim=trim, return_sample_id=return_sample_id, + return_language_id=return_language_id, channel_selector=channel_selector, ) diff --git a/nemo/collections/asr/data/audio_to_text_dataset.py b/nemo/collections/asr/data/audio_to_text_dataset.py index 14e8dea19..5aabdfc7a 100644 --- a/nemo/collections/asr/data/audio_to_text_dataset.py +++ b/nemo/collections/asr/data/audio_to_text_dataset.py @@ -151,6 +151,7 @@ def get_char_dataset(config: dict, augmentor: Optional['AudioAugmentor'] = None) trim=config.get('trim_silence', False), parser=config.get('parser', 'en'), return_sample_id=config.get('return_sample_id', False), + return_language_id=config.get('return_language_id', False), channel_selector=config.get('channel_selector', None), ) return dataset @@ -231,6 +232,7 @@ def get_bpe_dataset( trim=config.get('trim_silence', False), use_start_end_token=config.get('use_start_end_token', True), return_sample_id=config.get('return_sample_id', False), + return_language_id=config.get('return_language_id', False), channel_selector=config.get('channel_selector', None), ) return dataset @@ -630,7 +632,8 @@ def get_audio_to_text_bpe_dataset_from_config( logging.warning(f"`concat_sampling_probabilities` need to sum to 1. 
Config: {config}") return None - shuffle = config['shuffle'] + if config.get('shuffle', False): + shuffle = False device = 'gpu' if torch.cuda.is_available() else 'cpu' if config.get('use_dali', False): device_id = local_rank if device == 'gpu' else None diff --git a/nemo/collections/asr/losses/ssl_losses/contrastive.py b/nemo/collections/asr/losses/ssl_losses/contrastive.py index bab691913..4ca873a17 100644 --- a/nemo/collections/asr/losses/ssl_losses/contrastive.py +++ b/nemo/collections/asr/losses/ssl_losses/contrastive.py @@ -201,8 +201,17 @@ def forward(self, spectrograms, spec_masks, decoder_outputs, decoder_lengths=Non targets.transpose(0, 1), targets_masked_only.size(0), # TxBxC # T' ) else: - # only sample from masked steps in utterance negatives, _ = self.sample_negatives(targets_masked_only, targets_masked_only.size(0)) # T'xBxC # T' + # if targets_masked_only.size(0) >= self.num_negatives: + # # only sample from masked steps in utterance + # negatives, _ = self.sample_negatives(targets_masked_only, targets_masked_only.size(0)) # T'xBxC # T' + # else: # for shorter samples (<8s) + # # print(f"sampling from non-masked ({self.num_negatives},{targets_masked_only.size(0)})") + # # sample from all steps in utterance + # negatives, _ = self.sample_negatives( + # targets.transpose(0, 1), targets_masked_only.size(0), # TxBxC # T' + # ) + # NxT'xBxC out_masked_only = out_masked_only.reshape(-1, out_masked_only.shape[-1]) diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index 1ccc2d0ac..a933124ab 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -384,6 +384,7 @@ def rnnt_decoder_predictions_tensor( encoded_lengths: torch.Tensor, return_hypotheses: bool = False, partial_hypotheses: Optional[List[Hypothesis]] = None, + lang_ids: List[str] = None, ) -> Tuple[List[str], Optional[List[List[str]]], Optional[Union[Hypothesis, NBestHypotheses]]]: """ Decode an encoder output by autoregressive decoding of the Decoder+Joint networks. @@ -408,9 +409,10 @@ def rnnt_decoder_predictions_tensor( Look at rnnt_utils.NBestHypotheses for more information. 
""" # Compute hypotheses + # print("Decode strategy:", self.cfg.strategy) with torch.inference_mode(): hypotheses_list = self.decoding( - encoder_output=encoder_output, encoded_lengths=encoded_lengths, partial_hypotheses=partial_hypotheses + encoder_output=encoder_output, encoded_lengths=encoded_lengths, partial_hypotheses=partial_hypotheses, language_ids=lang_ids, ) # type: [List[Hypothesis]] # extract the hypotheses @@ -424,7 +426,7 @@ def rnnt_decoder_predictions_tensor( for nbest_hyp in prediction_list: # type: NBestHypotheses n_hyps = nbest_hyp.n_best_hypotheses # Extract all hypotheses for this sample - decoded_hyps = self.decode_hypothesis(n_hyps) # type: List[str] + decoded_hyps = self.decode_hypothesis(n_hyps, lang_ids) # type: List[str] # If computing timestamps if self.compute_timestamps is True: @@ -443,7 +445,7 @@ def rnnt_decoder_predictions_tensor( return best_hyp_text, all_hyp_text else: - hypotheses = self.decode_hypothesis(prediction_list) # type: List[str] + hypotheses = self.decode_hypothesis(prediction_list, lang_ids) # type: List[str] # If computing timestamps if self.compute_timestamps is True: @@ -462,7 +464,7 @@ def rnnt_decoder_predictions_tensor( best_hyp_text = [h.text for h in hypotheses] return best_hyp_text, None - def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hypothesis, NBestHypotheses]]: + def decode_hypothesis(self, hypotheses_list: List[Hypothesis], lang_ids: List[str] = None) -> List[Union[Hypothesis, NBestHypotheses]]: """ Decode a list of hypotheses into a list of strings. @@ -498,7 +500,10 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp token_repetitions = [1] * len(alignments) # preserve number of repetitions per token hypothesis = (prediction, alignments, token_repetitions) else: - hypothesis = self.decode_tokens_to_str(prediction) + if lang_ids is not None: + hypothesis = self.decode_tokens_to_str(prediction, lang_ids[ind]) + else: + hypothesis = self.decode_tokens_to_str(prediction) # TODO: remove # collapse leading spaces before . , ? for PC models diff --git a/nemo/collections/asr/metrics/rnnt_wer_bpe.py b/nemo/collections/asr/metrics/rnnt_wer_bpe.py index 99c71daeb..5606d083d 100644 --- a/nemo/collections/asr/metrics/rnnt_wer_bpe.py +++ b/nemo/collections/asr/metrics/rnnt_wer_bpe.py @@ -195,8 +195,9 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): tokenizer: The tokenizer which will be used for decoding. """ - def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec): - blank_id = tokenizer.tokenizer.vocab_size + def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec, blank_id=None): + if blank_id is None: + blank_id = tokenizer.tokenizer.vocab_size self.tokenizer = tokenizer super(RNNTBPEDecoding, self).__init__( @@ -222,7 +223,7 @@ def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]: hypothesis.words, hypothesis.token_confidence, hypothesis.y_sequence ) - def decode_tokens_to_str(self, tokens: List[int]) -> str: + def decode_tokens_to_str(self, tokens: List[int], lang: str = None) -> str: """ Implemented by subclass in order to decoder a token list into a string. @@ -232,7 +233,10 @@ def decode_tokens_to_str(self, tokens: List[int]) -> str: Returns: A decoded string. 
""" - hypothesis = self.tokenizer.ids_to_text(tokens) + if lang is not None: + hypothesis = self.tokenizer.ids_to_text(tokens, lang) + else: + hypothesis = self.tokenizer.ids_to_text(tokens) return hypothesis def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]: @@ -275,7 +279,7 @@ def decode_ids_to_langs(self, tokens: List[int]) -> List[str]: lang_list = self.tokenizer.ids_to_text_and_langs(tokens) return lang_list - def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hypothesis, NBestHypotheses]]: + def decode_hypothesis(self, hypotheses_list: List[Hypothesis], lang_ids: List[str] = None) -> List[Union[Hypothesis, NBestHypotheses]]: """ Decode a list of hypotheses into a list of strings. Overrides the super() method optionally adding lang information @@ -286,7 +290,7 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp Returns: A list of strings. """ - hypotheses = super().decode_hypothesis(hypotheses_list) + hypotheses = super().decode_hypothesis(hypotheses_list, lang_ids) if self.compute_langs: if isinstance(self.tokenizer, AggregateTokenizer): for ind in range(len(hypotheses_list)): @@ -371,6 +375,7 @@ def update( encoded_lengths: torch.Tensor, targets: torch.Tensor, target_lengths: torch.Tensor, + lang_ids: List[str] = None, ) -> torch.Tensor: words = 0 scores = 0 @@ -385,10 +390,13 @@ def update( for ind in range(targets_cpu_tensor.shape[0]): tgt_len = tgt_lenths_cpu_tensor[ind].item() target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist() - reference = self.decoding.decode_tokens_to_str(target) + if lang_ids is not None: + reference = self.decoding.decode_tokens_to_str(target, lang_ids[ind]) + else: + reference = self.decoding.decode_tokens_to_str(target) references.append(reference) - hypotheses, _ = self.decoding.rnnt_decoder_predictions_tensor(encoder_output, encoded_lengths) + hypotheses, _ = self.decoding.rnnt_decoder_predictions_tensor(encoder_output, encoded_lengths, lang_ids=lang_ids) if self.log_prediction: logging.info(f"\n") diff --git a/nemo/collections/asr/metrics/wer.py b/nemo/collections/asr/metrics/wer.py index 7f7f853d3..f1bc77f0a 100644 --- a/nemo/collections/asr/metrics/wer.py +++ b/nemo/collections/asr/metrics/wer.py @@ -378,6 +378,7 @@ def ctc_decoder_predictions_tensor( decoder_lengths: torch.Tensor = None, fold_consecutive: bool = True, return_hypotheses: bool = False, + lang_ids: List[str] = None, ) -> Tuple[List[str], Optional[List[List[str]]], Optional[Union[Hypothesis, NBestHypotheses]]]: """ Decodes a sequence of labels to words @@ -432,7 +433,7 @@ def ctc_decoder_predictions_tensor( for nbest_hyp in hypotheses_list: # type: NBestHypotheses n_hyps = nbest_hyp.n_best_hypotheses # Extract all hypotheses for this sample decoded_hyps = self.decode_hypothesis( - n_hyps, fold_consecutive + n_hyps, fold_consecutive, lang_ids ) # type: List[Union[Hypothesis, NBestHypotheses]] # If computing timestamps @@ -453,7 +454,7 @@ def ctc_decoder_predictions_tensor( else: hypotheses = self.decode_hypothesis( - hypotheses_list, fold_consecutive + hypotheses_list, fold_consecutive, lang_ids ) # type: List[Union[Hypothesis, NBestHypotheses]] # If computing timestamps @@ -476,7 +477,7 @@ def ctc_decoder_predictions_tensor( return best_hyp_text, None def decode_hypothesis( - self, hypotheses_list: List[Hypothesis], fold_consecutive: bool + self, hypotheses_list: List[Hypothesis], fold_consecutive: bool, lang_ids: List[str] = None, ) -> List[Union[Hypothesis, NBestHypotheses]]: """ Decode a list of 
hypotheses into a list of strings. @@ -541,8 +542,11 @@ def decode_hypothesis( # in order to compute exact time stamps. hypothesis = (decoded_prediction, token_lengths, token_repetitions) else: - hypothesis = self.decode_tokens_to_str(decoded_prediction) - + if lang_ids is not None: + hypothesis = self.decode_tokens_to_str(decoded_prediction, lang_ids[ind]) + else: + hypothesis = self.decode_tokens_to_str(decoded_prediction) + # TODO: remove # collapse leading spaces before . , ? for PC models hypothesis = re.sub(r'(\s+)([\.\,\?])', r'\2', hypothesis) @@ -1026,9 +1030,10 @@ class CTCDecoding(AbstractCTCDecoding): """ def __init__( - self, decoding_cfg, vocabulary, + self, decoding_cfg, vocabulary, blank_id = None ): - blank_id = len(vocabulary) + if blank_id is None: + blank_id = len(vocabulary) self.vocabulary = vocabulary self.labels_map = dict([(i, vocabulary[i]) for i in range(len(vocabulary))]) diff --git a/nemo/collections/asr/metrics/wer_bpe.py b/nemo/collections/asr/metrics/wer_bpe.py index 762acf172..3e3ee1923 100644 --- a/nemo/collections/asr/metrics/wer_bpe.py +++ b/nemo/collections/asr/metrics/wer_bpe.py @@ -138,8 +138,10 @@ class CTCBPEDecoding(AbstractCTCDecoding): tokenizer: NeMo tokenizer object, which inherits from TokenizerSpec. """ - def __init__(self, decoding_cfg, tokenizer: TokenizerSpec): - blank_id = tokenizer.tokenizer.vocab_size + def __init__(self, decoding_cfg, tokenizer: TokenizerSpec, blank_id = None): + + if blank_id is None: + blank_id = tokenizer.tokenizer.vocab_size self.tokenizer = tokenizer super().__init__(decoding_cfg=decoding_cfg, blank_id=blank_id) @@ -175,7 +177,7 @@ def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]: self.decode_tokens_to_str(hypothesis.text[0]).split(), hypothesis.token_confidence, hypothesis.text[0] ) - def decode_tokens_to_str(self, tokens: List[int]) -> str: + def decode_tokens_to_str(self, tokens: List[int], lang: str = None) -> str: """ Implemented by subclass in order to decoder a token list into a string. @@ -185,7 +187,10 @@ def decode_tokens_to_str(self, tokens: List[int]) -> str: Returns: A decoded string. 
""" - hypothesis = self.tokenizer.ids_to_text(tokens) + if lang is not None: + hypothesis = self.tokenizer.ids_to_text(tokens, lang) + else: + hypothesis = self.tokenizer.ids_to_text(tokens) return hypothesis def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]: @@ -263,6 +268,7 @@ def update( predictions: torch.Tensor, targets: torch.Tensor, target_lengths: torch.Tensor, + lang_ids: List[str] = None, predictions_lengths: torch.Tensor = None, ): """ @@ -286,12 +292,20 @@ def update( for ind in range(targets_cpu_tensor.shape[0]): tgt_len = tgt_lenths_cpu_tensor[ind].item() target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist() - reference = self.decoding.decode_tokens_to_str(target) + if lang_ids is not None: + reference = self.decoding.decode_tokens_to_str(target, lang_ids[ind]) + else: + reference = self.decoding.decode_tokens_to_str(target) references.append(reference) - hypotheses, _ = self.decoding.ctc_decoder_predictions_tensor( - predictions, predictions_lengths, fold_consecutive=self.fold_consecutive - ) + if lang_ids is not None: + hypotheses, _ = self.decoding.ctc_decoder_predictions_tensor( + predictions, predictions_lengths, fold_consecutive=self.fold_consecutive, lang_ids=lang_ids + ) + else: + hypotheses, _ = self.decoding.ctc_decoder_predictions_tensor( + predictions, predictions_lengths, fold_consecutive=self.fold_consecutive + ) if self.log_prediction: logging.info(f"\n") diff --git a/nemo/collections/asr/models/ctc_bpe_models.py b/nemo/collections/asr/models/ctc_bpe_models.py index a74c7f3de..14614a1b7 100644 --- a/nemo/collections/asr/models/ctc_bpe_models.py +++ b/nemo/collections/asr/models/ctc_bpe_models.py @@ -51,11 +51,11 @@ def __init__(self, cfg: DictConfig, trainer=None): # Set the new vocabulary with open_dict(cfg): # sidestepping the potential overlapping tokens issue in aggregate tokenizers - if self.tokenizer_type == "agg": + if self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual": cfg.decoder.vocabulary = ListConfig(vocabulary) else: cfg.decoder.vocabulary = ListConfig(list(vocabulary.keys())) - + # Override number of classes if placeholder provided num_classes = cfg.decoder["num_classes"] @@ -68,18 +68,33 @@ def __init__(self, cfg: DictConfig, trainer=None): cfg.decoder["num_classes"] = len(vocabulary) super().__init__(cfg=cfg, trainer=trainer) - + + # Multisoftmax + if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in cfg.decoder: + logging.info("Creating masks for multi-softmax layer.") + self.language_masks = {} + for language in self.tokenizer.tokenizers_dict.keys(): + self.language_masks[language] = [(token_language == language) for _, token_language in self.tokenizer.langs_by_token_id.items()] + self.language_masks[language].append(True) # Insert blank token + self.loss = CTCLoss( + num_classes=self.decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()), + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + self.decoder.language_masks = self.language_masks + # Setup decoding objects decoding_cfg = self.cfg.get('decoding', None) - # In case decoding config not found, use default config if decoding_cfg is None: decoding_cfg = OmegaConf.structured(CTCBPEDecodingConfig) with open_dict(self.cfg): self.cfg.decoding = decoding_cfg - - self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer) - + if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in cfg.decoder: + self.decoding = 
CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys())) + else: + self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer) + # Setup metric with decoding strategy self._wer = WERBPE( decoding=self.decoding, @@ -105,7 +120,6 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): # DALI Dataset implements dataloader interface return dataset - shuffle = config['shuffle'] if config.get('is_tarred', False): shuffle = False @@ -118,15 +132,25 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): # support datasets that are lists of lists collate_fn = dataset.datasets[0].datasets[0].collate_fn - return torch.utils.data.DataLoader( - dataset=dataset, - batch_size=config['batch_size'], - collate_fn=collate_fn, - drop_last=config.get('drop_last', False), - shuffle=shuffle, - num_workers=config.get('num_workers', 0), - pin_memory=config.get('pin_memory', False), - ) + if config.get('shuffle', False): + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=config['shuffle'], + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + else: + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': """ diff --git a/nemo/collections/asr/models/ctc_bpe_multisoftmax_models.py b/nemo/collections/asr/models/ctc_bpe_multisoftmax_models.py new file mode 100644 index 000000000..6cd608c42 --- /dev/null +++ b/nemo/collections/asr/models/ctc_bpe_multisoftmax_models.py @@ -0,0 +1,916 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
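Before the new `ctc_bpe_multisoftmax_models.py` file below, it may help to see the masking idea from the `ctc_bpe_models.py` hunk in isolation: one shared output layer covers every language's tokens plus a single blank, and a per-language boolean mask selects that language's slice of the logits. The snippet is a sketch with toy token/language data standing in for `tokenizer.langs_by_token_id`; it is not code from the patch.

```python
# Sketch of the per-language masking used by the "multisoftmax" decoder above.
# The token-to-language map here is toy data; the real masks are built from
# tokenizer.langs_by_token_id exactly as in the diff.
import torch

langs_by_token_id = {0: "en", 1: "en", 2: "en", 3: "hi", 4: "hi", 5: "hi"}
languages = ("en", "hi")

language_masks = {}
for language in languages:
    mask = [token_language == language for token_language in langs_by_token_id.values()]
    mask.append(True)  # every language keeps the shared blank token
    language_masks[language] = torch.tensor(mask)

# The shared decoder emits logits over the full (all languages + blank) axis;
# indexing with one language's mask leaves that language's vocabulary plus blank.
full_logits = torch.randn(1, 10, len(langs_by_token_id) + 1)  # B x T x (V_total + blank)
en_logits = full_logits[..., language_masks["en"]]            # B x T x (V_en + blank)
print(en_logits.shape)  # torch.Size([1, 10, 4])
```

This is also why the hunk constructs the CTC loss with `num_classes = self.decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys())`: after masking, each language scores only its own vocabulary plus the shared blank, which assumes every per-language tokenizer contributes an equally sized vocabulary.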
+ +import copy +import os +from typing import Dict, List, Optional, Union + +import torch +from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset +from nemo.collections.asr.losses.ctc import CTCLoss +from nemo.collections.asr.metrics.wer_bpe import WERBPE, CTCBPEDecoding, CTCBPEDecodingConfig +from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.parts.mixins import ASRBPEMixin +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.utils import logging, model_utils +from nemo.core.classes.mixins import AccessMixin +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType, StringType + +__all__ = ['EncDecCTCModelBPE'] + + +class EncDecCTCMultiSoftmaxModelBPE(EncDecCTCModel, ASRBPEMixin): + """Encoder decoder CTC-based models with Byte Pair Encoding.""" + + def __init__(self, cfg: DictConfig, trainer=None): + # Convert to Hydra 1.0 compatible DictConfig + cfg = model_utils.convert_model_config_to_dict_config(cfg) + cfg = model_utils.maybe_update_config_version(cfg) + + if 'tokenizer' not in cfg: + raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !") + + # Setup the tokenizer + self._setup_tokenizer(cfg.tokenizer) + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + # Set the new vocabulary + with open_dict(cfg): + # sidestepping the potential overlapping tokens issue in aggregate tokenizers + if self.tokenizer_type == "agg": + cfg.decoder.vocabulary = ListConfig(vocabulary) + else: + cfg.decoder.vocabulary = ListConfig(list(vocabulary.keys())) + + # Override number of classes if placeholder provided + num_classes = cfg.decoder["num_classes"] + + if num_classes < 1: + logging.info( + "\nReplacing placeholder number of classes ({}) with actual number of classes - {}".format( + num_classes, len(vocabulary) + ) + ) + cfg.decoder["num_classes"] = len(vocabulary) + + super().__init__(cfg=cfg, trainer=trainer) + + self.loss = CTCLoss( + num_classes=self.decoder._num_classes, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + + # # Multisoftmax + if self.tokenizer_type == "agg" and "multisoftmax" in cfg.decoder: + logging.info("Creating masks for multi-softmax layer.") + self.language_masks = {} + for language in self.tokenizer.tokenizers_dict.keys(): + self.language_masks[language] = [(token_language == language) for _, token_language in self.tokenizer.langs_by_token_id.items()] + self.language_masks[language].append(True) # Insert blank token + self.decoder.language_masks = self.language_masks + + # Setup decoding objects + decoding_cfg = self.cfg.get('decoding', None) + + # In case decoding config not found, use default config + if decoding_cfg is None: + decoding_cfg = OmegaConf.structured(CTCBPEDecodingConfig) + with open_dict(self.cfg): + self.cfg.decoding = decoding_cfg + + + self.decoding = {} + for language in self.decoder.languages: + self.decoding[language] = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer.tokenizers_dict[language]) + + self._wer_dict = {} + for language in self.decoder.languages: + self._wer_dict[language] = WERBPE( + decoding=self.decoding[language], + use_cer=self._cfg.get('use_cer', False), + 
dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + ) + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config( + config=config, + local_rank=self.local_rank, + global_rank=self.global_rank, + world_size=self.world_size, + tokenizer=self.tokenizer, + preprocessor_cfg=self.cfg.get("preprocessor", None), + ) + + if dataset is None: + return None + + if isinstance(dataset, AudioToBPEDALIDataset): + # DALI Dataset implements dataloader interface + return dataset + + if config.get('is_tarred', False): + shuffle = False + + if hasattr(dataset, 'collate_fn'): + collate_fn = dataset.collate_fn + elif hasattr(dataset.datasets[0], 'collate_fn'): + # support datasets that are lists of entries + collate_fn = dataset.datasets[0].collate_fn + else: + # support datasets that are lists of lists + collate_fn = dataset.datasets[0].datasets[0].collate_fn + + if config.get('shuffle', False): + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + else: + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + num_workers: (int) number of workers. Depends of the batch_size and machine. \ + 0 - only the main process will load batches, 1 - one worker (not main process) + + Returns: + A pytorch DataLoader for the given audio file(s). 
+ """ + + if 'manifest_filepath' in config: + manifest_filepath = config['manifest_filepath'] + batch_size = config['batch_size'] + else: + manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json') + batch_size = min(config['batch_size'], len(config['paths2audio_files'])) + + dl_config = { + 'manifest_filepath': manifest_filepath, + 'sample_rate': self.preprocessor._sample_rate, + 'batch_size': batch_size, + 'shuffle': False, + 'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)), + 'pin_memory': True, + 'channel_selector': config.get('channel_selector', None), + 'use_start_end_token': self.cfg.validation_ds.get('use_start_end_token', False), + } + + if config.get("augmentor"): + dl_config['augmentor'] = config.get("augmentor") + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_datalayer + + # PTL-specific methods + def training_step(self, batch, batch_nb): + # Reset access registry + if AccessMixin.is_access_enabled(): + AccessMixin.reset_registry(self) + + if self.is_interctc_enabled(): + AccessMixin.set_access_enabled(access_enabled=True) + + if "multisoftmax" not in self.cfg.decoder: + signal, signal_len, transcript, transcript_len = batch + language = None + else: + signal, signal_len, transcript, transcript_len, sample_ids, language_ids = batch + assert all(i == language_ids[0] for i in language_ids), f"Language ids are different for a batch -> {language_ids}" + language = language_ids[0] + + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + log_probs, encoded_len, predictions = self.forward( + processed_signal=signal, processed_signal_length=signal_len + ) + else: + if "multisoftmax" in self.cfg.decoder: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len, language_ids=language_ids) + else: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) + + if hasattr(self, '_trainer') and self._trainer is not None: + log_every_n_steps = self._trainer.log_every_n_steps + else: + log_every_n_steps = 1 + + loss_value = self.loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + + # Add auxiliary losses, if registered + loss_value = self.add_auxiliary_losses(loss_value) + # only computing WER when requested in the logs (same as done for final-layer WER below) + loss_value, tensorboard_logs = self.add_interctc_losses( + loss_value, transcript, transcript_len, compute_wer=((batch_nb + 1) % log_every_n_steps == 0) + ) + + # Reset access registry + if AccessMixin.is_access_enabled(): + AccessMixin.reset_registry(self) + + tensorboard_logs.update( + { + 'train_loss': loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32), + } + ) + + if (batch_nb + 1) % log_every_n_steps == 0: + self._wer_dict[language].update( + predictions=log_probs, + targets=transcript, + target_lengths=transcript_len, + predictions_lengths=encoded_len, + ) + wer, _, _ = self._wer_dict[language].compute() + self._wer_dict[language].reset() + tensorboard_logs.update({'training_batch_wer': wer}) + + return {'loss': loss_value, 'log': tensorboard_logs} + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + if "multisoftmax" not in self.cfg.decoder: + signal, signal_len, transcript, transcript_len = batch + language = None + else: + signal, signal_len, 
transcript, transcript_len, sample_ids, language_ids = batch + assert all(i == language_ids[0] for i in language_ids), f"Language ids are different for a batch -> {language_ids}" + language = language_ids[0] + + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + log_probs, encoded_len, predictions = self.forward( + processed_signal=signal, processed_signal_length=signal_len + ) + else: + if "multisoftmax" in self.cfg.decoder: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len, language_ids=language_ids) + else: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) + + transcribed_texts, _ = self._wer_dict[language].decoding.ctc_decoder_predictions_tensor( + decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False, + ) + + sample_id = sample_id.cpu().detach().numpy() + return list(zip(sample_id, transcribed_texts)) + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + if self.is_interctc_enabled(): + AccessMixin.set_access_enabled(access_enabled=True) + + if "multisoftmax" not in self.cfg.decoder: + signal, signal_len, transcript, transcript_len = batch + language = None + else: + signal, signal_len, transcript, transcript_len, sample_ids, language_ids = batch + assert all(i == language_ids[0] for i in language_ids), f"Language ids are different for a batch -> {language_ids}" + language = language_ids[0] + + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + log_probs, encoded_len, predictions = self.forward( + processed_signal=signal, processed_signal_length=signal_len + ) + else: + if "multisoftmax" in self.cfg.decoder: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len, language_ids=language_ids) + else: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) + + loss_value = self.loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + loss_value, metrics = self.add_interctc_losses( + loss_value, transcript, transcript_len, compute_wer=True, log_wer_num_denom=True, log_prefix="val_", + ) + + self._wer_dict[language].update( + predictions=log_probs, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len + ) + wer, wer_num, wer_denom = self._wer_dict[language].compute() + self._wer_dict[language].reset() + metrics.update({'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom, 'val_wer': wer}) + + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + # Reset access registry + if AccessMixin.is_access_enabled(): + AccessMixin.reset_registry(self) + + return metrics + + @torch.no_grad() + def transcribe( + self, + paths2audio_files: List[str], + language: str, + batch_size: int = 4, + logprobs: bool = False, + return_hypotheses: bool = False, + num_workers: int = 0, + channel_selector: Optional[ChannelSelectorType] = None, + augmentor: DictConfig = None, + verbose: bool = True, + ) -> List[str]: + """ + If modify this function, please remember update transcribe_partial_audio() in + nemo/collections/asr/parts/utils/trancribe_utils.py + + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. 
\ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + logprobs: (bool) pass True to get log probabilities instead of transcripts. + return_hypotheses: (bool) Either return hypotheses or text + With hypotheses can do some postprocessing like getting timestamp or rescoring + num_workers: (int) number of workers for DataLoader + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. + augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. + verbose: (bool) whether to display tqdm progress bar + Returns: + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if return_hypotheses and logprobs: + raise ValueError( + "Either `return_hypotheses` or `logprobs` can be True at any given time." + "Returned hypotheses will contain the logprobs." + ) + + if num_workers is None: + num_workers = min(batch_size, os.cpu_count() - 1) + + # We will store transcriptions here + hypotheses = [] + all_hypotheses = [] + + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + dither_value = self.preprocessor.featurizer.dither + pad_to_value = self.preprocessor.featurizer.pad_to + + try: + self.preprocessor.featurizer.dither = 0.0 + self.preprocessor.featurizer.pad_to = 0 + # Switch model to evaluation mode + self.eval() + # Freeze the encoder and decoder modules + self.encoder.freeze() + self.decoder.freeze() + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w', encoding='utf-8') as fp: + for audio_file in paths2audio_files: + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': ''} + fp.write(json.dumps(entry) + '\n') + + config = { + 'paths2audio_files': paths2audio_files, + 'batch_size': batch_size, + 'temp_dir': tmpdir, + 'num_workers': num_workers, + 'channel_selector': channel_selector, + } + + if augmentor: + config['augmentor'] = augmentor + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose): + logits, logits_len, greedy_predictions = self.forward( + input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) + ) + + if logprobs: + # dump log probs per file + for idx in range(logits.shape[0]): + lg = logits[idx][: logits_len[idx]] + hypotheses.append(lg.cpu().numpy()) + else: + current_hypotheses, all_hyp = self.decoding[language].ctc_decoder_predictions_tensor( + logits, decoder_lengths=logits_len, return_hypotheses=return_hypotheses, + ) + logits = logits.cpu() + + if return_hypotheses: + # dump log probs per file + for idx in range(logits.shape[0]): + current_hypotheses[idx].y_sequence = logits[idx][: logits_len[idx]] + if current_hypotheses[idx].alignments is None: + current_hypotheses[idx].alignments = current_hypotheses[idx].y_sequence + + if all_hyp is None: + hypotheses += current_hypotheses + else: + hypotheses += 
all_hyp + + del greedy_predictions + del logits + del test_batch + finally: + # set mode back to its original value + self.train(mode=mode) + self.preprocessor.featurizer.dither = dither_value + self.preprocessor.featurizer.pad_to = pad_to_value + if mode is True: + self.encoder.unfreeze() + self.decoder.unfreeze() + logging.set_verbosity(logging_level) + + return hypotheses + + + def change_vocabulary( + self, + new_tokenizer_dir: Union[str, DictConfig], + new_tokenizer_type: str, + decoding_cfg: Optional[DictConfig] = None, + ): + """ + Changes vocabulary of the tokenizer used during CTC decoding process. + Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + Args: + new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is `agg`) + new_tokenizer_type: Either `agg`, `bpe` or `wpe`. `bpe` is used for SentencePiece tokenizers, + whereas `wpe` is used for `BertTokenizer`. + new_tokenizer_cfg: A config for the new tokenizer. if provided, pre-empts the dir and type + + Returns: None + + """ + if isinstance(new_tokenizer_dir, DictConfig): + if new_tokenizer_type == 'agg': + new_tokenizer_cfg = new_tokenizer_dir + else: + raise ValueError( + f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer type is: {new_tokenizer_type}' + ) + else: + new_tokenizer_cfg = None + + if new_tokenizer_cfg is not None: + tokenizer_cfg = new_tokenizer_cfg + else: + if not os.path.isdir(new_tokenizer_dir): + raise NotADirectoryError( + f'New tokenizer dir must be non-empty path to a directory. But I got: {new_tokenizer_dir}' + f"New tokenizer dir must be non-empty path to a directory. 
But I got: {new_tokenizer_dir}" + ) + + if new_tokenizer_type.lower() not in ('bpe', 'wpe'): + raise ValueError(f'New tokenizer type must be either `bpe` or `wpe`') + + tokenizer_cfg = OmegaConf.create({'dir': new_tokenizer_dir, 'type': new_tokenizer_type}) + + # Setup the tokenizer + self._setup_tokenizer(tokenizer_cfg) + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + # Set the new vocabulary + decoder_config = copy.deepcopy(self.decoder.to_config_dict()) + # sidestepping the potential overlapping tokens issue in aggregate tokenizers + if self.tokenizer_type == "agg": + decoder_config.vocabulary = ListConfig(vocabulary) + else: + decoder_config.vocabulary = ListConfig(list(vocabulary.keys())) + + decoder_num_classes = decoder_config['num_classes'] + + # Override number of classes if placeholder provided + logging.info( + "\nReplacing old number of classes ({}) with new number of classes - {}".format( + decoder_num_classes, len(vocabulary) + ) + ) + + decoder_config['num_classes'] = len(vocabulary) + + del self.decoder + self.decoder = EncDecCTCModelBPE.from_config_dict(decoder_config) + del self.loss + self.loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + + if decoding_cfg is None: + # Assume same decoding config as before + decoding_cfg = self.cfg.decoding + + # Assert the decoding config with all hyper parameters + decoding_cls = OmegaConf.structured(CTCBPEDecodingConfig) + decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) + decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) + + self.decoding = {} + for language in self.decoder.languages: + self.decoding[language] = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer.tokenizers_dict[language]) + + self._wer_dict = {} + for language in self.decoder.languages: + self._wer_dict[language] = WERBPE( + decoding=self.decoding[language], + use_cer=self._cfg.get('use_cer', False), + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + ) + + # Update config + with open_dict(self.cfg.decoder): + self._cfg.decoder = decoder_config + + with open_dict(self.cfg.decoding): + self._cfg.decoding = decoding_cfg + + logging.info(f"Changed tokenizer to {self.decoder.vocabulary} vocabulary.") + + def change_decoding_strategy(self, decoding_cfg: DictConfig): + """ + Changes decoding strategy used during CTC decoding process. + + Args: + decoding_cfg: A config for the decoder, which is optional. If the decoding type + needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. 
+ """ + if decoding_cfg is None: + # Assume same decoding config as before + logging.info("No `decoding_cfg` passed when changing decoding strategy, using internal config") + decoding_cfg = self.cfg.decoding + + # Assert the decoding config with all hyper parameters + decoding_cls = OmegaConf.structured(CTCBPEDecodingConfig) + decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) + decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) + + self.decoding = {} + for language in self.decoder.languages: + self.decoding[language] = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer.tokenizers_dict[language]) + + self._wer_dict = {} + for language in self.decoder.languages: + self._wer_dict[language] = WERBPE( + decoding=self.decoding[language], + use_cer=self._cfg.get('use_cer', False), + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + ) + + self.decoder.temperature = decoding_cfg.get('temperature', 1.0) + + # Update config + with open_dict(self.cfg.decoding): + self.cfg.decoding = decoding_cfg + + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + results = [] + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_citrinet_256", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_256", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_256/versions/1.0.0rc1/files/stt_en_citrinet_256.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_citrinet_512", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_512", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_512/versions/1.0.0rc1/files/stt_en_citrinet_512.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_citrinet_1024", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_1024", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_1024/versions/1.0.0rc1/files/stt_en_citrinet_1024.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_citrinet_256_gamma_0_25", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_256_gamma_0_25", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_256_gamma_0_25/versions/1.0.0/files/stt_en_citrinet_256_gamma_0_25.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_citrinet_512_gamma_0_25", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_512_gamma_0_25", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_512_gamma_0_25/versions/1.0.0/files/stt_en_citrinet_512_gamma_0_25.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_citrinet_1024_gamma_0_25", + description="For details about this model, please visit 
https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_1024_gamma_0_25", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_1024_gamma_0_25/versions/1.0.0/files/stt_en_citrinet_1024_gamma_0_25.nemo", + ) + + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_es_citrinet_512", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_citrinet_512", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_es_citrinet_512/versions/1.0.0/files/stt_es_citrinet_512.nemo", + ) + + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_de_citrinet_1024", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_de_citrinet_1024", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_de_citrinet_1024/versions/1.5.0/files/stt_de_citrinet_1024.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_fr_citrinet_1024_gamma_0_25", + description="For details about this model, please visit https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_fr_citrinet_1024_gamma_0_25", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_fr_citrinet_1024_gamma_0_25/versions/1.5/files/stt_fr_citrinet_1024_gamma_0_25.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_fr_no_hyphen_citrinet_1024_gamma_0_25", + description="For details about this model, please visit https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_fr_citrinet_1024_gamma_0_25", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_fr_citrinet_1024_gamma_0_25/versions/1.5/files/stt_fr_no_hyphen_citrinet_1024_gamma_0_25.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_es_citrinet_1024_gamma_0_25", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_citrinet_1024_gamma_0_25", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_es_citrinet_1024_gamma_0_25/versions/1.8.0/files/stt_es_citrinet_1024_gamma_0_25.nemo", + ) + + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_conformer_ctc_small", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_small", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_conformer_ctc_small/versions/1.6.0/files/stt_en_conformer_ctc_small.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_conformer_ctc_medium", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_medium", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_conformer_ctc_medium/versions/1.6.0/files/stt_en_conformer_ctc_medium.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_conformer_ctc_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_conformer_ctc_large/versions/1.10.0/files/stt_en_conformer_ctc_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + 
pretrained_model_name="stt_en_conformer_ctc_xlarge", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_xlarge", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_conformer_ctc_xlarge/versions/1.10.0/files/stt_en_conformer_ctc_xlarge.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_squeezeformer_ctc_xsmall_ls", + description="For details about this model, please visit https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_squeezeformer_ctc_xsmall_ls", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_squeezeformer_ctc_xsmall_ls/versions/1.13.0/files/stt_en_squeezeformer_ctc_xsmall_ls.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_squeezeformer_ctc_small_ls", + description="For details about this model, please visit https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_squeezeformer_ctc_small_ls", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_squeezeformer_ctc_small_ls/versions/1.13.0/files/stt_en_squeezeformer_ctc_small_ls.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_squeezeformer_ctc_small_medium_ls", + description="For details about this model, please visit https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_squeezeformer_ctc_small_medium_ls", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_squeezeformer_ctc_small_medium_ls/versions/1.13.0/files/stt_en_squeezeformer_ctc_small_medium_ls.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_squeezeformer_ctc_medium_ls", + description="For details about this model, please visit https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_squeezeformer_ctc_medium_ls", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_squeezeformer_ctc_medium_ls/versions/1.13.0/files/stt_en_squeezeformer_ctc_medium_ls.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_squeezeformer_ctc_medium_large_ls", + description="For details about this model, please visit https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_squeezeformer_ctc_medium_large_ls", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_squeezeformer_ctc_medium_large_ls/versions/1.13.0/files/stt_en_squeezeformer_ctc_medium_large_ls.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_squeezeformer_ctc_large_ls", + description="For details about this model, please visit https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_squeezeformer_ctc_large_ls", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_squeezeformer_ctc_large_ls/versions/1.13.0/files/stt_en_squeezeformer_ctc_large_ls.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_conformer_ctc_small_ls", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_small_ls", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_conformer_ctc_small_ls/versions/1.0.0/files/stt_en_conformer_ctc_small_ls.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_conformer_ctc_medium_ls", + 
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_medium_ls", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_conformer_ctc_medium_ls/versions/1.0.0/files/stt_en_conformer_ctc_medium_ls.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_conformer_ctc_large_ls", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_large_ls", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_conformer_ctc_large_ls/versions/1.0.0/files/stt_en_conformer_ctc_large_ls.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_fr_conformer_ctc_large", + description="For details about this model, please visit https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_fr_conformer_ctc_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_fr_conformer_ctc_large/versions/1.5.1/files/stt_fr_conformer_ctc_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_fr_no_hyphen_conformer_ctc_large", + description="For details about this model, please visit https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_fr_conformer_ctc_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_fr_conformer_ctc_large/versions/1.5.1/files/stt_fr_no_hyphen_conformer_ctc_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_de_conformer_ctc_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_de_conformer_ctc_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_de_conformer_ctc_large/versions/1.5.0/files/stt_de_conformer_ctc_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_es_conformer_ctc_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_conformer_ctc_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_es_conformer_ctc_large/versions/1.8.0/files/stt_es_conformer_ctc_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_hi_conformer_ctc_medium", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_hi_conformer_ctc_medium", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_hi_conformer_ctc_medium/versions/1.6.0/files/stt_hi_conformer_ctc_medium.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_mr_conformer_ctc_medium", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_mr_conformer_ctc_medium", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_mr_conformer_ctc_medium/versions/1.6.0/files/stt_mr_conformer_ctc_medium.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_enes_conformer_ctc_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_enes_conformer_ctc_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_enes_conformer_ctc_large/versions/1.0.0/files/stt_enes_conformer_ctc_large.nemo", + ) + results.append(model) + + 
model = PretrainedModelInfo( + pretrained_model_name="stt_ca_conformer_ctc_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ca_conformer_ctc_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_ca_conformer_ctc_large/versions/1.11.0/files/stt_ca_conformer_ctc_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_rw_conformer_ctc_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_rw_conformer_ctc_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_rw_conformer_ctc_large/versions/1.11.0/files/stt_rw_conformer_ctc_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_enes_conformer_ctc_large_codesw", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_enes_conformer_ctc_large_codesw", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_enes_conformer_ctc_large_codesw/versions/1.0.0/files/stt_enes_conformer_ctc_large_codesw.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_be_conformer_ctc_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_be_conformer_ctc_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_be_conformer_ctc_large/versions/1.12.0/files/stt_be_conformer_ctc_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_hr_conformer_ctc_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_hr_conformer_ctc_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_hr_conformer_ctc_large/versions/1.11.0/files/stt_hr_conformer_ctc_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_it_conformer_ctc_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_it_conformer_ctc_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_it_conformer_ctc_large/versions/1.13.0/files/stt_it_conformer_ctc_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_ru_conformer_ctc_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ru_conformer_ctc_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_ru_conformer_ctc_large/versions/1.13.0/files/stt_ru_conformer_ctc_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_eo_conformer_ctc_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_eo_conformer_ctc_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_eo_conformer_ctc_large/versions/1.14.0/files/stt_eo_conformer_ctc_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_fastconformer_ctc_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_ctc_large", + 
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_fastconformer_ctc_large/versions/1.0.0/files/stt_en_fastconformer_ctc_large.nemo", + ) + results.append(model) + + return results diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 1446e1ce8..7d7b10b94 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -32,7 +32,7 @@ from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.core.classes.common import PretrainedModelInfo, typecheck from nemo.core.classes.mixins import AccessMixin -from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType, StringType from nemo.utils import logging __all__ = ['EncDecCTCModel'] @@ -69,7 +69,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): cfg.decoder["num_classes"] = len(self.cfg.decoder.vocabulary) self.decoder = EncDecCTCModel.from_config_dict(self._cfg.decoder) - + self.loss = CTCLoss( num_classes=self.decoder.num_classes_with_blank - 1, zero_infinity=True, @@ -99,6 +99,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): dist_sync_on_step=True, log_prediction=self._cfg.get("log_prediction", False), ) + + self.language_masks = None # Only supported for CTC_BPE models # Setup optional Optimization flags self.setup_optimization_flags() @@ -281,7 +283,7 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[Di decoding_cls = OmegaConf.structured(CTCDecodingConfig) decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) - + self.decoding = CTCDecoding( decoding_cfg=decoding_cfg, vocabulary=OmegaConf.to_container(self.decoder.vocabulary) ) @@ -349,6 +351,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): # Automatically inject args from model config to dataloader config audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate') audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='labels') + print("CONFIG:", config.return_language_id) dataset = audio_to_text_dataset.get_audio_to_text_char_dataset_from_config( config=config, local_rank=self.local_rank, @@ -377,16 +380,26 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): # support datasets that are lists of lists collate_fn = dataset.datasets[0].datasets[0].collate_fn - return torch.utils.data.DataLoader( - dataset=dataset, - batch_size=config['batch_size'], - collate_fn=collate_fn, - drop_last=config.get('drop_last', False), - shuffle=shuffle, - num_workers=config.get('num_workers', 0), - pin_memory=config.get('pin_memory', False), - ) - + if config.get('shuffle', False): + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=config['shuffle'], + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + else: + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + def setup_training_data(self, 
train_data_config: Optional[Union[DictConfig, Dict]]): """ Sets up the training data loader via a Dict-like object. @@ -402,8 +415,8 @@ def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` """ - if 'shuffle' not in train_data_config: - train_data_config['shuffle'] = True + # if 'shuffle' not in train_data_config: + # train_data_config['shuffle'] = True # preserve config self._update_dataset_config(dataset_name='train', config=train_data_config) @@ -486,6 +499,7 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]: "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), "sample_id": NeuralType(tuple('B'), LengthsType(), optional=True), + 'language_ids': [NeuralType(('B'), StringType(), optional=True)], } @property @@ -498,7 +512,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: @typecheck() def forward( - self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None + self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None, language_ids=None ): """ Forward pass of the model. @@ -539,7 +553,7 @@ def forward( encoder_output = self.encoder(audio_signal=processed_signal, length=processed_signal_length) encoded = encoder_output[0] encoded_len = encoder_output[1] - log_probs = self.decoder(encoder_output=encoded) + log_probs = self.decoder(encoder_output=encoded, language_ids=language_ids) greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) return ( @@ -557,19 +571,26 @@ def training_step(self, batch, batch_nb): if self.is_interctc_enabled(): AccessMixin.set_access_enabled(access_enabled=True) - signal, signal_len, transcript, transcript_len = batch + if "multisoftmax" not in self.cfg.decoder: + signal, signal_len, transcript, transcript_len = batch + else: + signal, signal_len, transcript, transcript_len, sample_ids, language_ids = batch + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: log_probs, encoded_len, predictions = self.forward( processed_signal=signal, processed_signal_length=signal_len ) else: - log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) + if "multisoftmax" in self.cfg.decoder: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len, language_ids=language_ids) + else: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) if hasattr(self, '_trainer') and self._trainer is not None: log_every_n_steps = self._trainer.log_every_n_steps else: log_every_n_steps = 1 - + loss_value = self.loss( log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len ) @@ -594,12 +615,21 @@ def training_step(self, batch, batch_nb): ) if (batch_nb + 1) % log_every_n_steps == 0: - self._wer.update( - predictions=log_probs, - targets=transcript, - target_lengths=transcript_len, - predictions_lengths=encoded_len, - ) + if "multisoftmax" in self.cfg.decoder: + self._wer.update( + predictions=log_probs, + targets=transcript, + target_lengths=transcript_len, + predictions_lengths=encoded_len, + lang_ids=language_ids, + ) + else: + self._wer.update( + predictions=log_probs, + targets=transcript, + 
target_lengths=transcript_len, + predictions_lengths=encoded_len, + ) wer, _, _ = self._wer.compute() self._wer.reset() tensorboard_logs.update({'training_batch_wer': wer}) @@ -607,17 +637,30 @@ def training_step(self, batch, batch_nb): return {'loss': loss_value, 'log': tensorboard_logs} def predict_step(self, batch, batch_idx, dataloader_idx=0): - signal, signal_len, transcript, transcript_len, sample_id = batch + if "multisoftmax" not in self.cfg.decoder: + signal, signal_len, transcript, transcript_len = batch + else: + signal, signal_len, transcript, transcript_len, sample_ids, language_ids = batch + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: log_probs, encoded_len, predictions = self.forward( processed_signal=signal, processed_signal_length=signal_len ) + transcribed_texts, _ = self._wer.decoding.ctc_decoder_predictions_tensor( + decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False, + ) else: - log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) - - transcribed_texts, _ = self._wer.decoding.ctc_decoder_predictions_tensor( - decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False, - ) + if "multisoftmax" in self.cfg.decoder: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len, language_ids=language_ids) + transcribed_texts, _ = self._wer.decoding.ctc_decoder_predictions_tensor( + decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False, lang_ids=language_ids, + ) + else: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) + transcribed_texts, _ = self._wer.decoding.ctc_decoder_predictions_tensor( + decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False, + ) + sample_id = sample_id.cpu().detach().numpy() return list(zip(sample_id, transcribed_texts)) @@ -626,13 +669,20 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): if self.is_interctc_enabled(): AccessMixin.set_access_enabled(access_enabled=True) - signal, signal_len, transcript, transcript_len = batch + if "multisoftmax" not in self.cfg.decoder: + signal, signal_len, transcript, transcript_len = batch + else: + signal, signal_len, transcript, transcript_len, sample_ids, language_ids = batch + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: log_probs, encoded_len, predictions = self.forward( processed_signal=signal, processed_signal_length=signal_len ) else: - log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) + if "multisoftmax" in self.cfg.decoder: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len, language_ids=language_ids) + else: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) loss_value = self.loss( log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len @@ -640,10 +690,14 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): loss_value, metrics = self.add_interctc_losses( loss_value, transcript, transcript_len, compute_wer=True, log_wer_num_denom=True, log_prefix="val_", ) - - self._wer.update( - predictions=log_probs, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len - ) + if "multisoftmax" in self.cfg.decoder: + self._wer.update( + predictions=log_probs, 
targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len, lang_ids=language_ids, + ) + else: + self._wer.update( + predictions=log_probs, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len, + ) wer, wer_num, wer_denom = self._wer.compute() self._wer.reset() metrics.update({'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom, 'val_wer': wer}) diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index b88669a1f..feafcbc30 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -73,7 +73,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): ) with open_dict(cfg): - if self.tokenizer_type == "agg": + if self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual": cfg.aux_ctc.decoder.vocabulary = ListConfig(vocabulary) else: cfg.aux_ctc.decoder.vocabulary = ListConfig(list(vocabulary.keys())) @@ -92,6 +92,40 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.decoding = RNNTBPEDecoding( decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, ) + + # Multisoftmax + self.language_masks = None + if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in cfg.decoder: + logging.info("Creating masks for multi-softmax layer.") + self.language_masks = {} + self.token_id_offsets = self.tokenizer.token_id_offset + self.offset_token_ids_by_token_id = self.tokenizer.offset_token_ids_by_token_id + for language in self.tokenizer.tokenizers_dict.keys(): + self.language_masks[language] = [(token_language == language) for _, token_language in self.tokenizer.langs_by_token_id.items()] + self.language_masks[language].append(True) # Insert blank token + self.ctc_loss = CTCLoss( + num_classes=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()), + zero_infinity=True, + reduction=self.cfg.aux_ctc.get("ctc_reduction", "mean_batch"), + ) + # Setup RNNT Loss + loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None)) + self.loss = RNNTLoss( + num_classes=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()), + loss_name=loss_name, + loss_kwargs=loss_kwargs, + reduction=self.cfg.get("rnnt_reduction", "mean_batch"), + ) + # Setup decoding object + self.decoding = RNNTBPEDecoding( + decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()) + ) + + self.decoder.language_masks = self.language_masks + self.joint.language_masks = self.language_masks + self.joint.token_id_offsets = self.token_id_offsets + self.joint.offset_token_ids_by_token_id = self.offset_token_ids_by_token_id + self.ctc_decoder.language_masks = self.language_masks # Setup wer object self.wer = RNNTBPEWER( @@ -113,8 +147,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): ctc_decoding_cfg = OmegaConf.structured(CTCBPEDecodingConfig) with open_dict(self.cfg.aux_ctc): self.cfg.aux_ctc.decoding = ctc_decoding_cfg - self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer) - + if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in cfg.decoder: + self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, 
blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys())) + else: + self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer) + # Setup CTC WER self.ctc_wer = WERBPE( decoding=self.ctc_decoding, @@ -124,7 +161,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): ) # setting the RNNT decoder as the default one - self.cur_decoder = "rnnt" + # self.cur_decoder = "rnnt" + self.cur_decoder = "ctc" def _setup_dataloader_from_config(self, config: Optional[Dict]): dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config( @@ -143,7 +181,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): # DALI Dataset implements dataloader interface return dataset - shuffle = config['shuffle'] + # shuffle = config['shuffle'] if config.get('is_tarred', False): shuffle = False @@ -156,15 +194,25 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): # support datasets that are lists of lists collate_fn = dataset.datasets[0].datasets[0].collate_fn - return torch.utils.data.DataLoader( - dataset=dataset, - batch_size=config['batch_size'], - collate_fn=collate_fn, - drop_last=config.get('drop_last', False), - shuffle=shuffle, - num_workers=config.get('num_workers', 0), - pin_memory=config.get('pin_memory', False), - ) + if config.get('shuffle', False): + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=config['shuffle'], + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + else: + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': """ diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index 5ca6124ec..6f6bfcb77 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -87,6 +87,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # setting the RNNT decoder as the default one self.cur_decoder = "rnnt" + # self.cur_decoder = "ctc" # setting up interCTC loss (from InterCTCMixin) self.setup_interctc(decoder_name='ctc_decoder', loss_name='ctc_loss', wer_name='ctc_wer') @@ -374,7 +375,11 @@ def training_step(self, batch, batch_nb): if self.is_interctc_enabled(): AccessMixin.set_access_enabled(access_enabled=True) - signal, signal_len, transcript, transcript_len = batch + if "multisoftmax" not in self.cfg.decoder: + signal, signal_len, transcript, transcript_len = batch + language_ids = None + else: + signal, signal_len, transcript, transcript_len, sample_ids, language_ids = batch # forward() only performs encoder forward if isinstance(batch, DALIOutputs) and batch.has_processed_signal: @@ -401,7 +406,7 @@ def training_step(self, batch, batch_nb): # If fused Joint-Loss-WER is not used if not self.joint.fuse_loss_wer: # Compute full joint and loss - joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder, language_ids=language_ids) loss_value = self.loss( log_probs=joint, targets=transcript, input_lengths=encoded_len, 
target_lengths=target_length ) @@ -429,6 +434,7 @@ def training_step(self, batch, batch_nb): transcripts=transcript, transcript_lengths=transcript_len, compute_wer=compute_wer, + language_ids=language_ids ) # Add auxiliary losses, if registered @@ -443,7 +449,7 @@ def training_step(self, batch, batch_nb): tensorboard_logs.update({'training_batch_wer': wer}) if self.ctc_loss_weight > 0: - log_probs = self.ctc_decoder(encoder_output=encoded) + log_probs = self.ctc_decoder(encoder_output=encoded, language_ids=language_ids) ctc_loss = self.ctc_loss( log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len ) @@ -451,12 +457,21 @@ def training_step(self, batch, batch_nb): tensorboard_logs['train_ctc_loss'] = ctc_loss loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss if compute_wer: - self.ctc_wer.update( - predictions=log_probs, - targets=transcript, - target_lengths=transcript_len, - predictions_lengths=encoded_len, - ) + if "multisoftmax" in self.cfg.decoder: + self.ctc_wer.update( + predictions=log_probs, + targets=transcript, + target_lengths=transcript_len, + predictions_lengths=encoded_len, + lang_ids=language_ids, + ) + else: + self.ctc_wer.update( + predictions=log_probs, + targets=transcript, + target_lengths=transcript_len, + predictions_lengths=encoded_len, + ) ctc_wer, _, _ = self.ctc_wer.compute() self.ctc_wer.reset() tensorboard_logs.update({'training_batch_wer_ctc': ctc_wer}) @@ -486,19 +501,30 @@ def training_step(self, batch, batch_nb): def predict_step(self, batch, batch_idx, dataloader_idx=0): # TODO: add support for CTC decoding - signal, signal_len, transcript, transcript_len, sample_id = batch + if "multisoftmax" not in self.cfg.decoder: + signal, signal_len, transcript, transcript_len = batch + language_ids = None + else: + signal, signal_len, transcript, transcript_len, sample_ids, language_ids = batch # forward() only performs encoder forward if isinstance(batch, DALIOutputs) and batch.has_processed_signal: encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) + best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor( + encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False + ) else: encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len) + if "multisoftmax" in self.cfg.decoder: + best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor( + encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False, lang_ids=language_ids + ) + else: + best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor( + encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False + ) del signal - best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor( - encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False - ) - sample_id = sample_id.cpu().detach().numpy() return list(zip(sample_id, best_hyp_text)) @@ -506,7 +532,11 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): if self.is_interctc_enabled(): AccessMixin.set_access_enabled(access_enabled=True) - signal, signal_len, transcript, transcript_len = batch + if "multisoftmax" not in self.cfg.decoder: + signal, signal_len, transcript, transcript_len = batch + language_ids=None + else: + signal, signal_len, transcript, transcript_len, sample_ids, language_ids = batch # forward() only performs encoder forward if isinstance(batch, 
DALIOutputs) and batch.has_processed_signal: @@ -522,14 +552,14 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): if not self.joint.fuse_loss_wer: if self.compute_eval_loss: decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len) - joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder, language_ids=language_ids) loss_value = self.loss( log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length ) tensorboard_logs['val_loss'] = loss_value - self.wer.update(encoded, encoded_len, transcript, transcript_len) + self.wer.update(encoded, encoded_len, transcript, transcript_len, lang_ids=language_ids) wer, wer_num, wer_denom = self.wer.compute() self.wer.reset() @@ -555,6 +585,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): transcripts=transcript, transcript_lengths=target_len, compute_wer=compute_wer, + language_ids=language_ids ) if loss_value is not None: tensorboard_logs['val_loss'] = loss_value @@ -563,7 +594,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): tensorboard_logs['val_wer_denom'] = wer_denom tensorboard_logs['val_wer'] = wer - log_probs = self.ctc_decoder(encoder_output=encoded) + log_probs = self.ctc_decoder(encoder_output=encoded, language_ids=language_ids) if self.compute_eval_loss: ctc_loss = self.ctc_loss( log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len @@ -573,7 +604,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss tensorboard_logs['val_loss'] = loss_value self.ctc_wer.update( - predictions=log_probs, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len, + predictions=log_probs, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len, lang_ids=language_ids ) ctc_wer, ctc_wer_num, ctc_wer_denom = self.ctc_wer.compute() self.ctc_wer.reset() diff --git a/nemo/collections/asr/models/rnnt_bpe_models.py b/nemo/collections/asr/models/rnnt_bpe_models.py index 6fed8be9d..59cb6d4fb 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models.py +++ b/nemo/collections/asr/models/rnnt_bpe_models.py @@ -479,7 +479,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): # DALI Dataset implements dataloader interface return dataset - shuffle = config['shuffle'] + # shuffle = config['shuffle'] if config.get('is_tarred', False): shuffle = False @@ -492,15 +492,25 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): # support datasets that are lists of lists collate_fn = dataset.datasets[0].datasets[0].collate_fn - return torch.utils.data.DataLoader( - dataset=dataset, - batch_size=config['batch_size'], - collate_fn=collate_fn, - drop_last=config.get('drop_last', False), - shuffle=shuffle, - num_workers=config.get('num_workers', 0), - pin_memory=config.get('pin_memory', False), - ) + if config.get('shuffle', False): + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=config['shuffle'], + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + else: + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + 
drop_last=config.get('drop_last', False), + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': """ diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index 84e086358..30d66df5e 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -482,16 +482,26 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): # support datasets that are lists of lists collate_fn = dataset.datasets[0].datasets[0].collate_fn - return torch.utils.data.DataLoader( - dataset=dataset, - batch_size=config['batch_size'], - collate_fn=collate_fn, - drop_last=config.get('drop_last', False), - shuffle=shuffle, - num_workers=config.get('num_workers', 0), - pin_memory=config.get('pin_memory', False), - ) - + if config.get('shuffle', False): + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=config['shuffle'], + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + else: + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): """ Sets up the training data loader via a Dict-like object. @@ -507,8 +517,8 @@ def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` """ - if 'shuffle' not in train_data_config: - train_data_config['shuffle'] = True + # if 'shuffle' not in train_data_config: + # train_data_config['shuffle'] = True # preserve config self._update_dataset_config(dataset_name='train', config=train_data_config) @@ -602,7 +612,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: @typecheck() def forward( - self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None + self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None, language_ids=None ): """ Forward pass of the model. 
Note that for RNNT Models, the forward pass of the model is a 3 step process, diff --git a/nemo/collections/asr/modules/__init__.py b/nemo/collections/asr/modules/__init__.py index ecd430b56..e8a6f34be 100644 --- a/nemo/collections/asr/modules/__init__.py +++ b/nemo/collections/asr/modules/__init__.py @@ -34,6 +34,7 @@ ParallelConvASREncoder, SpeakerDecoder, ) +from nemo.collections.asr.modules.multi_conv_asr import MultiConvASRDecoder from nemo.collections.asr.modules.graph_decoder import ViterbiDecoderWithGraph from nemo.collections.asr.modules.hybrid_autoregressive_transducer import HATJoint from nemo.collections.asr.modules.lstm_decoder import LSTMDecoder diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py index a05ee894f..a203dbd59 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -47,6 +47,7 @@ LogprobsType, NeuralType, SpectrogramType, + StringType, ) from nemo.utils import logging @@ -409,13 +410,16 @@ class ConvASRDecoder(NeuralModule, Exportable, adapter_mixins.AdapterModuleMixin @property def input_types(self): - return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())}) + return OrderedDict({ + "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + 'language_ids': [NeuralType(('B'), StringType(), optional=True)], + }) @property def output_types(self): return OrderedDict({"logprobs": NeuralType(('B', 'T', 'D'), LogprobsType())}) - def __init__(self, feat_in, num_classes, init_mode="xavier_uniform", vocabulary=None): + def __init__(self, feat_in, num_classes, init_mode="xavier_uniform", vocabulary=None, multisoftmax=False, language_masks=None): super().__init__() if vocabulary is None and num_classes < 0: @@ -447,20 +451,63 @@ def __init__(self, feat_in, num_classes, init_mode="xavier_uniform", vocabulary= # to change, requires running ``model.temperature = T`` explicitly self.temperature = 1.0 + + self.multisoftmax = multisoftmax + self.language_masks = language_masks + def masked_softmax(self, x, mask=None): + """ + Performs masked softmax, as simply masking post-softmax can be + inaccurate + :param x: [batch_size, num_items] + :param mask: [batch_size, num_items] + :return: + """ + if mask is not None: + mask = mask.float() + if mask is not None: + x_masked = x * mask + (1 - 1 / mask) + else: + x_masked = x + # print(x_masked[0][0]) + x_max = x_masked.max(-1)[0] + x_exp = (x - x_max.unsqueeze(-1)).exp() + if mask is not None: + x_exp = x_exp * mask.float() + # return (x - x_max.unsqueeze(-1)) / torch.log(x_exp.sum(-1).unsqueeze(-1)) + return x_exp / x_exp.sum(-1).unsqueeze(-1) + @typecheck() - def forward(self, encoder_output): + def forward(self, encoder_output, language_ids=None): # Adapter module forward step if self.is_adapter_available(): encoder_output = encoder_output.transpose(1, 2) # [B, T, C] encoder_output = self.forward_enabled_adapters(encoder_output) encoder_output = encoder_output.transpose(1, 2) # [B, C, T] - + if self.temperature != 1.0: - return torch.nn.functional.log_softmax( - self.decoder_layers(encoder_output).transpose(1, 2) / self.temperature, dim=-1 - ) - return torch.nn.functional.log_softmax(self.decoder_layers(encoder_output).transpose(1, 2), dim=-1) + decoder_output = self.decoder_layers(encoder_output).transpose(1, 2) / self.temperature + else: + decoder_output = self.decoder_layers(encoder_output).transpose(1, 2) + + if language_ids is not None: + sample_mask = [] + for lang_idx 
in language_ids: + sample_mask.append(self.language_masks[lang_idx]) + sample_mask = torch.tensor(sample_mask, dtype=torch.bool) # .to(decoder_output.device) + # Repeat across timesteps [B, T, C] + sample_mask = sample_mask.unsqueeze(1) + mask = sample_mask.repeat(1, decoder_output.shape[1], 1) + # Send mask to GPU + mask = mask.to(decoder_output.device) + # masked_output = self.masked_softmax(decoder_output, mask) # B x T x 3073 -> B x T x 257 + decoder_output = torch.masked_select(decoder_output, mask).view(decoder_output.shape[0],decoder_output.shape[1],-1) + else: + masked_output = None + # print(mask[0][0]) + # softmax_output = self.masked_softmax(decoder_output, mask) + # return softmax_output + return torch.nn.functional.log_softmax(decoder_output, dim=-1) def input_example(self, max_batch=1, max_dim=256): """ @@ -582,7 +629,7 @@ def __init__( padding=kernel_size // 2, ) ) - self.decoder_layers.append(nn.Conv1d(self.feat_hidden, self.feat_hidden, kernel_size=1, bias=True)) + self.decoder_layers.append(nn.Conv1d(self.feat_hidden, self.feat_hidden, kernel_size=1, bias=True)) self.decoder_layers.append(nn.BatchNorm1d(self.feat_hidden, eps=1e-3, momentum=0.1)) self.decoder_layers.append(activation) diff --git a/nemo/collections/asr/modules/multi_conv_asr.py b/nemo/collections/asr/modules/multi_conv_asr.py new file mode 100644 index 000000000..1f0178888 --- /dev/null +++ b/nemo/collections/asr/modules/multi_conv_asr.py @@ -0,0 +1,170 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+from collections import OrderedDict +from dataclasses import dataclass, field +from typing import List, Optional, Set, Union + +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf + +from nemo.collections.asr.parts.submodules.jasper import ( + JasperBlock, + MaskedConv1d, + ParallelBlock, + SqueezeExcite, + init_weights, + jasper_activations, +) +from nemo.collections.asr.parts.submodules.tdnn_attention import ( + AttentivePoolLayer, + StatsPoolLayer, + TDNNModule, + TDNNSEModule, +) +from nemo.collections.asr.parts.utils import adapter_utils +from nemo.core.classes.common import typecheck +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.mixins import AccessMixin, adapter_mixins +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types import ( + AcousticEncodedRepresentation, + LengthsType, + LogitsType, + LogprobsType, + NeuralType, + SpectrogramType, + StringType, +) +from nemo.utils import logging + +__all__ = ['MultiConvASRDecoder'] + + +class MultiConvASRDecoder(NeuralModule, Exportable, adapter_mixins.AdapterModuleMixin): + """Simple ASR Decoder for use with CTC-based models such as JasperNet and QuartzNet + + Based on these papers: + https://arxiv.org/pdf/1904.03288.pdf + https://arxiv.org/pdf/1910.10261.pdf + https://arxiv.org/pdf/2005.04290.pdf + """ + + @property + def input_types(self): + return OrderedDict({ + "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + 'language_ids': [NeuralType(('B'), StringType(), optional=True)], + }) + + @property + def output_types(self): + return OrderedDict({"logprobs": NeuralType(('B', 'T', 'D'), LogprobsType())}) + + def __init__(self, feat_in, languages, num_classes_per_lang, init_mode="xavier_uniform", num_classes=None, vocabulary=None, multisoftmax=True): + super().__init__() + + # if vocabulary is None and num_classes < 0: + # raise ValueError( + # f"Neither of the vocabulary and num_classes are set! At least one of them need to be set." + # ) + + # if num_classes <= 0: + # num_classes = len(vocabulary) + # logging.info(f"num_classes of ConvASRDecoder is set to the size of the vocabulary: {num_classes}.") + + if vocabulary is not None: + # if num_classes != len(vocabulary): + # raise ValueError( + # f"If vocabulary is specified, it's length should be equal to the num_classes. 
Instead got: num_classes={num_classes} and len(vocabulary)={len(vocabulary)}" + # ) + self.__vocabulary = vocabulary + self._feat_in = feat_in + # Add 1 for blank char + self._num_classes_per_lang = [] + self.languages = languages + for num_classes in num_classes_per_lang: + self._num_classes_per_lang.append(num_classes + 1) + self._num_classes = self._num_classes_per_lang[0] + + + self.decoder_layers = {} + for lang, num_classes in zip(self.languages, self._num_classes_per_lang): + self.decoder_layers[lang] = torch.nn.Sequential( + torch.nn.Conv1d(self._feat_in, num_classes, kernel_size=1, bias=True) + ) + self.decoder_layers = torch.nn.ModuleDict(self.decoder_layers) + self.apply(lambda x: init_weights(x, mode=init_mode)) + + accepted_adapters = [adapter_utils.LINEAR_ADAPTER_CLASSPATH] + self.set_accepted_adapter_types(accepted_adapters) + + # to change, requires running ``model.temperature = T`` explicitly + self.temperature = 1.0 + + @typecheck() + def forward(self, encoder_output, language_ids): + # Adapter module forward step + if self.is_adapter_available(): + encoder_output = encoder_output.transpose(1, 2) # [B, T, C] + encoder_output = self.forward_enabled_adapters(encoder_output) + encoder_output = encoder_output.transpose(1, 2) # [B, C, T] + + language = language_ids[0] + if self.temperature != 1.0: + decoder_output = self.decoder_layers[language](encoder_output).transpose(1, 2) / self.temperature + else: + decoder_output = self.decoder_layers[language](encoder_output).transpose(1, 2) + + return torch.nn.functional.log_softmax(decoder_output, dim=-1) + + def input_example(self, max_batch=1, max_dim=256): + """ + Generates input examples for tracing etc. + Returns: + A tuple of input examples. + """ + input_example = torch.randn(max_batch, self._feat_in, max_dim).to(next(self.parameters()).device) + return tuple([input_example]) + + def _prepare_for_export(self, **kwargs): + m_count = 0 + for m in self.modules(): + if type(m).__name__ == "MaskedConv1d": + m.use_mask = False + m_count += 1 + if m_count > 0: + logging.warning(f"Turned off {m_count} masked convolutions") + Exportable._prepare_for_export(self, **kwargs) + + # Adapter method overrides + def add_adapter(self, name: str, cfg: DictConfig): + # Update the config with correct input dim + cfg = self._update_adapter_cfg_input_dim(cfg) + # Add the adapter + super().add_adapter(name=name, cfg=cfg) + + def _update_adapter_cfg_input_dim(self, cfg: DictConfig): + cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self._feat_in) + return cfg + + @property + def vocabulary(self): + return self.__vocabulary + + @property + def num_classes_with_blank(self): + return self._num_classes diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 04bdd25ac..c34c21dd1 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -48,6 +48,7 @@ LossType, NeuralType, SpectrogramType, + StringType ) from nemo.utils import logging @@ -572,6 +573,9 @@ def __init__( normalization_mode: Optional[str] = None, random_state_sampling: bool = False, blank_as_pad: bool = True, + multisoftmax=False, + language_masks=None, + ): # Required arguments self.pred_hidden = prednet['pred_hidden'] @@ -602,6 +606,9 @@ def __init__( rnn_hidden_size=prednet.get("rnn_hidden_size", -1), ) self._rnnt_export = False + + self.multisoftmax = multisoftmax + self.language_masks = language_masks @typecheck() def forward(self, targets, target_length, states=None): @@ -976,7 +983,7 @@ 
def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List Returns: (tuple): decoder states for given id - ([L x (1, H)], [L x (1, H)]) + ([L x (1, H)], [L x (1, H)]s) """ if batch_states is not None: state_list = [] @@ -1130,6 +1137,7 @@ def input_types(self): "transcripts": NeuralType(('B', 'T'), LabelsType(), optional=True), "transcript_lengths": NeuralType(tuple('B'), LengthsType(), optional=True), "compute_wer": NeuralType(optional=True), + 'language_ids': [NeuralType(('B'), StringType(), optional=True)], } @property @@ -1181,6 +1189,11 @@ def __init__( fuse_loss_wer: bool = False, fused_batch_size: Optional[int] = None, experimental_fuse_loss_wer: Any = None, + language_masks=None, + multilingual: bool = False, + language_keys: Optional[List] = None, + token_id_offsets=None, + offset_token_ids_by_token_id=None, ): super().__init__() @@ -1189,6 +1202,11 @@ def __init__( self._vocab_size = num_classes self._num_extra_outputs = num_extra_outputs self._num_classes = num_classes + 1 + num_extra_outputs # 1 is for blank + self.language_masks = language_masks + self.token_id_offsets = token_id_offsets + self.offset_token_ids_by_token_id = offset_token_ids_by_token_id + self.multilingual = multilingual + self.language_keys = language_keys if experimental_fuse_loss_wer is not None: # Override fuse_loss_wer from deprecated argument @@ -1247,6 +1265,7 @@ def forward( transcripts: Optional[torch.Tensor] = None, transcript_lengths: Optional[torch.Tensor] = None, compute_wer: bool = False, + language_ids=None, ) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]: # encoder = (B, D, T) # decoder = (B, D, U) if passed, else None @@ -1262,7 +1281,24 @@ def forward( "decoder_outputs can only be None for fused step!" ) - out = self.joint(encoder_outputs, decoder_outputs) # [B, T, U, V + 1] + out = self.joint(encoder_outputs, decoder_outputs, language_ids=language_ids) # [B, T, U, V + 1] + + # if language_ids is not None: + # sample_mask = [] + # for lang_idx in language_ids: + # sample_mask.append(self.language_masks[lang_idx]) + # sample_mask = torch.tensor(sample_mask, dtype=torch.bool) # .to(decoder_output.device) + # # Repeat across timesteps [B, T, U, V + 1] + # sample_mask = sample_mask.unsqueeze(1) + # mask = sample_mask.repeat(1, out.shape[1], 1) + # sample_mask = sample_mask.unsqueeze(2) + # mask = sample_mask.repeat(1, 1, out.shape[2], 1) + # # Send mask to GPU + # mask = mask.to(out.device) + # # masked_output = self.masked_softmax(decoder_output, mask) + # # print("Before mask", sub_joint.shape) + # out = torch.masked_select(out, mask).view(out.shape[0],out.shape[1],out.shape[2],-1) + return out else: @@ -1318,8 +1354,27 @@ def forward( sub_dec = sub_dec.narrow(dim=1, start=0, length=int(max_sub_transcript_length + 1)) # Perform joint => [sub-batch, T', U', V + 1] - sub_joint = self.joint(sub_enc, sub_dec) - + if language_ids is not None: + sub_joint = self.joint(sub_enc, sub_dec, language_ids=language_ids[begin:end]) + else: + sub_joint = self.joint(sub_enc, sub_dec) + + # if language_ids is not None: + # sample_mask = [] + # for lang_idx in language_ids[begin:end]: + # sample_mask.append(self.language_masks[lang_idx]) + # sample_mask = torch.tensor(sample_mask, dtype=torch.bool) # .to(decoder_output.device) + # # Repeat across timesteps [sub-batch, T, U, V + 1] + # sample_mask = sample_mask.unsqueeze(1) + # mask = sample_mask.repeat(1, sub_joint.shape[1], 1) + # sample_mask = sample_mask.unsqueeze(2) + # mask = sample_mask.repeat(1, 1, sub_joint.shape[2], 1) + # # 
Send mask to GPU + # mask = mask.to(sub_joint.device) + # # masked_output = self.masked_softmax(decoder_output, mask) + # # print("Before mask", sub_joint.shape) + # sub_joint = torch.masked_select(sub_joint, mask).view(sub_joint.shape[0],sub_joint.shape[1],sub_joint.shape[2],-1) + # print("After mask", sub_joint.shape) del sub_dec # Reduce transcript length to correct alignment @@ -1357,7 +1412,10 @@ def forward( sub_transcripts = sub_transcripts.detach() # Update WER on each process without syncing - self.wer.update(sub_enc, sub_enc_lens, sub_transcripts, sub_transcript_lens) + if language_ids is not None: + self.wer.update(sub_enc, sub_enc_lens, sub_transcripts, sub_transcript_lens, lang_ids=language_ids[begin:end]) + else: + self.wer.update(sub_enc, sub_enc_lens, sub_transcripts, sub_transcript_lens) del sub_enc, sub_transcripts, sub_enc_lens, sub_transcript_lens @@ -1377,7 +1435,7 @@ def forward( return losses, wer, wer_num, wer_denom - def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor: + def joint(self, f: torch.Tensor, g: torch.Tensor, language_ids=None) -> torch.Tensor: """ Compute the joint step of the network. @@ -1422,8 +1480,20 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor: # Forward adapter modules on joint hidden if self.is_adapter_available(): inp = self.forward_enabled_adapters(inp) - - res = self.joint_net(inp) # [B, T, U, V + 1] + + # res = self.joint_net(inp) # [B, T, U, V + 1] + + if language_ids is not None: + + # Do partial forward of joint net (skipping the final linear) + for module in self.joint_net[:-1]: + inp = module(inp) # [B, T, U, H] + res_single = [] + for single_inp, lang in zip(inp, language_ids): + res_single.append(self.joint_net[-1][lang](single_inp)) + res = torch.stack(res_single) + else: + res = self.joint_net(inp) # [B, T, U, V + 1] del inp @@ -1473,11 +1543,22 @@ def _joint_net_modules(self, num_classes, pred_n_hidden, enc_n_hidden, joint_n_h elif activation == 'tanh': activation = torch.nn.Tanh() - layers = ( - [activation] - + ([torch.nn.Dropout(p=dropout)] if dropout else []) - + [torch.nn.Linear(joint_n_hidden, num_classes)] - ) + if self.multilingual: + final_layer = torch.nn.ModuleDict() + logging.info(f"Vocab size for each language: {self._vocab_size // len(self.language_keys)}") + for lang in self.language_keys: + final_layer[lang] = torch.nn.Linear(joint_n_hidden, (self._vocab_size // len(self.language_keys)+1)) + layers = ( + [activation] + + ([torch.nn.Dropout(p=dropout)] if dropout else []) + + [final_layer] + ) + else: + layers = ( + [activation] + + ([torch.nn.Dropout(p=dropout)] if dropout else []) + + [torch.nn.Linear(joint_n_hidden, num_classes)] + ) return pred, enc, torch.nn.Sequential(*layers) # Adapter method overrides @@ -1688,6 +1769,9 @@ def __init__( preserve_memory: bool = False, fuse_loss_wer: bool = False, fused_batch_size: Optional[int] = None, + language_masks=None, + token_id_offsets=None, + offset_token_ids_by_token_id=None, ): super().__init__( jointnet=jointnet, @@ -1697,6 +1781,9 @@ def __init__( preserve_memory=preserve_memory, fuse_loss_wer=fuse_loss_wer, fused_batch_size=fused_batch_size, + language_masks=language_masks, + token_id_offsets=token_id_offsets, + offset_token_ids_by_token_id=offset_token_ids_by_token_id, ) self.n_samples = n_samples self.register_buffer('blank_id', torch.tensor([self.num_classes_with_blank - 1]), persistent=False) @@ -1710,6 +1797,7 @@ def forward( transcripts: Optional[torch.Tensor] = None, transcript_lengths: Optional[torch.Tensor] = None, 
compute_wer: bool = False, + language_ids=None, ) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]: # If in inference mode, revert to basic RNNT Joint behaviour. # Sampled RNNT is only used for training. @@ -1722,6 +1810,7 @@ def forward( transcripts=transcripts, transcript_lengths=transcript_lengths, compute_wer=compute_wer, + language_ids=language_ids, ) if transcripts is None or transcript_lengths is None: @@ -1799,7 +1888,7 @@ def forward( # Perform sampled joint => [sub-batch, T', U', {V' < V} + 1}] sub_joint, sub_transcripts_remapped = self.sampled_joint( - sub_enc, sub_dec, transcript=sub_transcripts, transcript_lengths=sub_transcript_lens + sub_enc, sub_dec, transcript=sub_transcripts, transcript_lengths=sub_transcript_lens, language_ids=language_ids[begin:end], ) del sub_dec @@ -1842,7 +1931,10 @@ def forward( sub_transcripts = sub_transcripts.detach() # Update WER on each process without syncing - self.wer.update(sub_enc, sub_enc_lens, sub_transcripts, sub_transcript_lens) + if language_ids is not None: + self.wer.update(sub_enc, sub_enc_lens, sub_transcripts, sub_transcript_lens, lang_ids=language_ids[begin:end]) + else: + self.wer.update(sub_enc, sub_enc_lens, sub_transcripts, sub_transcript_lens) del sub_enc, sub_transcripts, sub_enc_lens, sub_transcript_lens @@ -1863,7 +1955,7 @@ def forward( return losses, wer, wer_num, wer_denom def sampled_joint( - self, f: torch.Tensor, g: torch.Tensor, transcript: torch.Tensor, transcript_lengths: torch.Tensor, + self, f: torch.Tensor, g: torch.Tensor, transcript: torch.Tensor, transcript_lengths: torch.Tensor, language_ids=None, ) -> torch.Tensor: """ Compute the sampled joint step of the network. @@ -1932,7 +2024,19 @@ def sampled_joint( # Begin compute of sampled RNNT joint with torch.no_grad(): - # gather true labels + + if language_ids is not None: + transcript_with_offset = [] + for t, lang in zip(transcript, language_ids): + offset_transcript = [] + for t_token in t.tolist(): + if t_token != 0: + offset_transcript.append(t_token+self.token_id_offsets[lang]) + else: + offset_transcript.append(0) + transcript_with_offset.append(offset_transcript) + transcript = torch.tensor(transcript_with_offset, dtype=transcript.dtype, device=transcript.device) + transcript_vocab_ids = torch.unique(transcript) # augment with blank token id @@ -1966,6 +2070,10 @@ def sampled_joint( # new_transcript = [1, 0, 2, 3, 2, 0] index = torch.bucketize(transcript.ravel(), palette) transcript = key[index].reshape(transcript.shape) + if language_ids is not None: + # remap to original transcript ids which are without offsets for multi-softmax + new_transcript = [[self.offset_token_ids_by_token_id[idx.item()] for idx in t] for t in transcript] + transcript = torch.tensor(new_transcript, dtype=transcript.dtype) transcript = transcript.to(t_device) # Extract out partial weight tensor and bias tensor of just the V_Pos vocabulary from the full joint. @@ -2027,7 +2135,25 @@ def sampled_joint( # Finally, construct the sampled joint as the V_Sampled = Union(V_Pos, V_Neg) # Here, we simply concatenate the two tensors to construct the joint with V_Sampled vocab # because before we have properly asserted that Intersection(V_Pos, V_Neg) is a null set. 
+ # print(transcript_scores.shape, noise_scores.shape) res = torch.cat([transcript_scores, noise_scores], dim=-1) + + # Multisoftmax language-wise sampling of output + if language_ids is not None: + sample_mask = [] + sampled_vocab = transcript_vocab_ids.tolist() + accept_samples.tolist() + for lang_idx in language_ids: + sample_mask.append([self.language_masks[lang_idx][v] for v in sampled_vocab]) + sample_mask = torch.tensor(sample_mask, dtype=torch.bool) # .to(decoder_output.device) + # Repeat across timesteps [B, T, U, V + 1] + sample_mask = sample_mask.unsqueeze(1) + sample_mask = sample_mask.repeat(1, res.shape[1], 1) + sample_mask = sample_mask.unsqueeze(2) + mask = sample_mask.repeat(1, 1, res.shape[2], 1) + # Send mask to GPU + mask = mask.to(res.device) + print(res.shape, mask.shape) + res = torch.masked_select(res, mask).view(res.shape[0],res.shape[1],res.shape[2],-1) del inp diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py index eba896d04..e97c92d04 100644 --- a/nemo/collections/asr/parts/mixins/mixins.py +++ b/nemo/collections/asr/parts/mixins/mixins.py @@ -57,6 +57,8 @@ def _setup_tokenizer(self, tokenizer_cfg: DictConfig): raise ValueError("`tokenizer.type` cannot be None") elif tokenizer_type.lower() == 'agg': self._setup_aggregate_tokenizer(tokenizer_cfg) + elif tokenizer_type.lower() == 'multilingual': + self._setup_multilingual_tokenizer(tokenizer_cfg) else: self._setup_monolingual_tokenizer(tokenizer_cfg) @@ -215,6 +217,49 @@ def _setup_aggregate_tokenizer(self, tokenizer_cfg: DictConfig): ][lang]['type'] self.tokenizer = tokenizers.AggregateTokenizer(tokenizers_dict) + + def _setup_multilingual_tokenizer(self, tokenizer_cfg: DictConfig): + # Prevent tokenizer parallelism (unless user has explicitly set it) + if 'TOKENIZERS_PARALLELISM' not in os.environ: + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + + self.tokenizer_cfg = OmegaConf.to_container(tokenizer_cfg, resolve=True) # type: dict + + # the aggregate tokenizer does not have one tokenizer_dir but multiple ones + self.tokenizer_dir = None + + self.tokenizer_cfg.pop('dir', None) # Remove tokenizer directory, if any + # Remove tokenizer_type -- obviously if we are here, the type is 'agg' + self.tokenizer_type = self.tokenizer_cfg.pop('type').lower() + + # the aggregate tokenizer should not have these + self.hf_tokenizer_kwargs = {} + self.tokenizer_cfg.pop("hf_kwargs", {}) # Remove HF tokenizer kwargs, if any + + logging.info('_setup_tokenizer: detected an aggregate tokenizer') + # need to de-register any monolingual config items if they exist + self._cleanup_monolingual_and_aggregate_config_and_artifacts_if_needed() + + # overwrite tokenizer type + if hasattr(self, 'cfg') and 'tokenizer' in self.cfg: + self.cfg.tokenizer.type = self.tokenizer_type + + tokenizers_dict = {} + # init each of the monolingual tokenizers found in the config and assemble into AggregateTokenizer + for lang, tokenizer_config in self.tokenizer_cfg[self.AGGREGATE_TOKENIZERS_DICT_PREFIX].items(): + (tokenizer, model_path, vocab_path, spe_vocab_path,) = self._make_tokenizer(tokenizer_config, lang) + + tokenizers_dict[lang] = tokenizer + if hasattr(self, 'cfg'): + with open_dict(self.cfg.tokenizer): + self.cfg.tokenizer[self.AGGREGATE_TOKENIZERS_DICT_PREFIX][lang]['dir'] = self.tokenizer_cfg[ + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + ][lang]['dir'] + self.cfg.tokenizer[self.AGGREGATE_TOKENIZERS_DICT_PREFIX][lang]['type'] = self.tokenizer_cfg[ + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + 
+                    ][lang]['type']
+
+        self.tokenizer = tokenizers.MultilingualTokenizer(tokenizers_dict)

    def _make_tokenizer(self, tokenizer_cfg: DictConfig, lang=None):
diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
index 5e98b03f2..4b055bc76 100644
--- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -38,7 +38,7 @@
 from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureMixin, ConfidenceMethodConfig
 from nemo.collections.common.parts.rnn import label_collate
 from nemo.core.classes import Typing, typecheck
-from nemo.core.neural_types import AcousticEncodedRepresentation, ElementType, HypothesisType, LengthsType, NeuralType
+from nemo.core.neural_types import AcousticEncodedRepresentation, ElementType, HypothesisType, LengthsType, NeuralType, StringType
 from nemo.utils import logging
@@ -140,6 +140,7 @@ def input_types(self):
            "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
            "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
            "partial_hypotheses": [NeuralType(elements_type=HypothesisType(), optional=True)],  # must always be last
+           "language_ids": [NeuralType(('B'), StringType(), optional=True)],  # optional per-sample language ids for multisoftmax
        }
@@ -213,7 +214,7 @@ def _pred_step(
        # output: [B, 1, K]
        return self.decoder.predict(label, hidden, add_sos=add_sos, batch_size=batch_size)
-    def _joint_step(self, enc, pred, log_normalize: Optional[bool] = None):
+    def _joint_step(self, enc, pred, log_normalize: Optional[bool] = None, language_ids=None):
        """
        Common joint step based on AbstractRNNTJoint implementation.
@@ -226,8 +227,13 @@ def _joint_step(self, enc, pred, log_normalize: Optional[bool] = None):
            logits of shape (B, T=1, U=1, V + 1)
        """
        with torch.no_grad():
-           logits = self.joint.joint(enc, pred)
-
+           # Multisoftmax: call the full joint (instead of joint.joint) so language_ids can be routed
+           # to the language-specific output head. Fused loss/WER is disabled for the call so the
+           # joint returns raw logits, and is restored immediately afterwards.
+           self.joint._fuse_loss_wer = False
+           logits = self.joint(encoder_outputs=enc.transpose(1, 2), decoder_outputs=pred.transpose(1, 2), language_ids=language_ids)
+           self.joint._fuse_loss_wer = True
+
            if log_normalize is None:
                if not logits.is_cuda:  # Use log softmax only if on CPU
                    logits = logits.log_softmax(dim=len(logits.shape) - 1)
@@ -569,6 +575,7 @@ def forward(
        encoder_output: torch.Tensor,
        encoded_lengths: torch.Tensor,
        partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
+       language_ids=None,
    ):
        """Returns a list of hypotheses given an input batch of the encoder hidden embedding. Output token is generated auto-regressively.
@@ -596,7 +603,7 @@ def forward( with self.decoder.as_frozen(), self.joint.as_frozen(): inseq = encoder_output # [B, T, D] hypotheses = self._greedy_decode( - inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses + inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses, language_ids=language_ids ) # Pack the hypotheses results @@ -613,6 +620,7 @@ def _greedy_decode_blank_as_pad( out_len: torch.Tensor, device: torch.device, partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None, + language_ids=None, ): if partial_hypotheses is not None: raise NotImplementedError("`partial_hypotheses` support is not supported") @@ -682,7 +690,7 @@ def _greedy_decode_blank_as_pad( # Batched joint step - Output = [B, V + 1] # If preserving per-frame confidence, log_normalize must be true - logp = self._joint_step(f, g, log_normalize=True if self.preserve_frame_confidence else None)[ + logp = self._joint_step(f, g, log_normalize=True if self.preserve_frame_confidence else None, language_ids=language_ids)[ :, 0, 0, : ] @@ -818,6 +826,7 @@ def _greedy_decode_masked( out_len: torch.Tensor, device: torch.device, partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None, + language_ids=None, ): if partial_hypotheses is not None: raise NotImplementedError("`partial_hypotheses` support is not supported") @@ -898,7 +907,7 @@ def _greedy_decode_masked( # Batched joint step - Output = [B, V + 1] # If preserving per-frame confidence, log_normalize must be true - logp = self._joint_step(f, g, log_normalize=True if self.preserve_frame_confidence else None)[ + logp = self._joint_step(f, g, log_normalize=True if self.preserve_frame_confidence else None, language_ids=language_ids)[ :, 0, 0, : ] diff --git a/nemo/collections/common/data/dataset.py b/nemo/collections/common/data/dataset.py index 030e99780..97c7fdd77 100644 --- a/nemo/collections/common/data/dataset.py +++ b/nemo/collections/common/data/dataset.py @@ -63,7 +63,6 @@ def __init__( self.world_size = world_size self.sampling_kwargs = {} self.sampling_scale = sampling_scale - if sampling_technique == 'temperature': self.index_generator = ConcatDataset.temperature_generator self.sampling_kwargs['temperature'] = sampling_temperature diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 4616f95e1..5c3c35990 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from tqdm import tqdm import collections import json import os @@ -138,9 +139,9 @@ def __init__( if index_by_file_id: self.mapping = {} - for id_, audio_file, duration, offset, text, speaker, orig_sr, token_labels, lang in zip( + for id_, audio_file, duration, offset, text, speaker, orig_sr, token_labels, lang in tqdm(zip( ids, audio_files, durations, offsets, texts, speakers, orig_sampling_rates, token_labels, langs - ): + )): # Duration filters. 
            if min_duration is not None and duration < min_duration:
                duration_filtered += duration
@@ -217,7 +218,7 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs):
            [],
        )
        speakers, orig_srs, token_labels, langs = [], [], [], []
-       for item in manifest.item_iter(manifests_files):
+       for item in tqdm(manifest.item_iter(manifests_files)):
            ids.append(item['id'])
            audio_files.append(item['audio_file'])
            durations.append(item['duration'])
diff --git a/nemo/collections/common/tokenizers/__init__.py b/nemo/collections/common/tokenizers/__init__.py
index f46e3b150..57ff9fae2 100644
--- a/nemo/collections/common/tokenizers/__init__.py
+++ b/nemo/collections/common/tokenizers/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer
+from nemo.collections.common.tokenizers.multilingual_tokenizer import MultilingualTokenizer
 from nemo.collections.common.tokenizers.bytelevel_tokenizers import ByteLevelTokenizer
 from nemo.collections.common.tokenizers.char_tokenizer import CharTokenizer
 from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
diff --git a/nemo/collections/common/tokenizers/multilingual_tokenizer.py b/nemo/collections/common/tokenizers/multilingual_tokenizer.py
new file mode 100644
index 000000000..a26f9230e
--- /dev/null
+++ b/nemo/collections/common/tokenizers/multilingual_tokenizer.py
@@ -0,0 +1,235 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Union
+
+import numpy as np
+
+from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+from nemo.utils import logging
+
+__all__ = ['MultilingualTokenizer']
+
+
+class DummyTokenizer:
+    def __init__(self, vocab):
+        self.vocab = vocab
+        self.vocab_size = len(vocab)
+
+    # minimum compatibility: all the monolingual tokenizers expose a vocab;
+    # additional methods could be added here
+    def get_vocab(self):
+        return self.vocab
+
+
+class MultilingualTokenizer(TokenizerSpec):
+    '''
+    MultilingualTokenizer, allowing one to combine multiple regular monolingual tokenizers into one tokenizer.
+    The intuition is that we can use existing tokenizers "as is", without retraining, and associate each tokenizer with a language id
+    during text processing (language id will be used to route the incoming text sample to the right tokenizer)
+    as well as a token id range for detokenization (e.g. [0..127] for tokenizer A, [128..255] for tokenizer B) so
+    that the original text could be reconstructed. Note that we assume that the incoming dict of langs / tokenizers
+    is ordered, i.e.
+    the first tokenizer will be assigned a lower interval of token ids
+    Args:
+        tokenizers: dict of tokenizers, keys are lang ids, values are actual tokenizers
+    '''
+
+    def __init__(self, tokenizers: Dict):
+
+        self.tokenizers_dict = tokenizers
+        self.vocabulary = []
+
+        # the tokenizers should produce non-overlapping, ordered token ids
+        # keys are language ids
+        self.token_id_offset = {}
+
+        # keys are tokenizer numbers
+        self.token_id_offset_by_tokenizer_num = {}
+        offset = 0
+        i = 0
+        for lang, tokenizer in self.tokenizers_dict.items():
+            self.token_id_offset[lang] = offset
+            self.token_id_offset_by_tokenizer_num[i] = offset
+            offset += len(tokenizer.vocab)
+            i += 1
+
+        for tokenizer in self.tokenizers_dict.values():
+            self.vocabulary.extend(tokenizer.vocab)
+
+        self.vocab_size = len(self.vocabulary)
+        logging.info(f'Multilingual (combined) vocab size: {self.vocab_size}')
+
+        # for compatibility purposes only -- right now only the get_vocab method
+        # is supported, returning the joint vocab across all tokenizers
+        self.tokenizer = DummyTokenizer(self.vocabulary)
+
+        # lookup tables to speed up token to text operations
+        # if there are two tokenizers, [0, 1], ['en', 'es'], each with 128 tokens, the combined
+        # token range will be [0, 255]. The method below provides three lookup tables:
+        # one, to convert an incoming combined token id, e.g. 200, into its real id (200 - 128 = 72)
+        # second, to compute the tokenizer number that should process that token (1)
+        # third, to compute the lang id for that token ('es')
+        offset_token_ids_by_token_id, tokenizers_by_token_id, langs_by_token_id = self._calculate_offsets()
+
+        self.offset_token_ids_by_token_id = offset_token_ids_by_token_id
+        self.tokenizers_by_token_id = tokenizers_by_token_id
+        self.langs_by_token_id = langs_by_token_id
+
+    def _calculate_offsets(self):
+        offsets = {}
+        tokenizers = {}
+        langs = {}
+        cur_num = 0
+        tot = len(self.tokenizers_dict)
+        for id in range(len(self.vocabulary)):
+            off_id = id - list(self.token_id_offset.values())[cur_num]
+            if cur_num + 1 < tot:
+                if id >= list(self.token_id_offset.values())[cur_num + 1]:
+                    cur_num += 1
+                    off_id = id - list(self.token_id_offset.values())[cur_num]
+            offsets[id] = off_id
+            tokenizers[id] = list(self.tokenizers_dict.values())[cur_num]
+            langs[id] = list(self.tokenizers_dict.keys())[cur_num]
+
+        return offsets, tokenizers, langs
+
+    def text_to_tokens(self, text, lang_id):
+        tokenizer = self.tokenizers_dict[lang_id]
+        return tokenizer.text_to_tokens(text)
+
+    def text_to_ids(self, text, lang_id):
+        tokenizer = self.tokenizers_dict[lang_id]
+        token_ids = tokenizer.text_to_ids(text)
+        # note: unlike AggregateTokenizer, no global offset is applied here -- ids stay in the
+        # monolingual tokenizer's own range and lang_id is used for routing
+
+        return token_ids
+
+    def tokens_to_text(self, tokens, lang_id):
+        if isinstance(tokens, np.ndarray):
+            tokens = tokens.tolist()
+
+        tokenizer = self.tokenizers_dict[lang_id]
+        return tokenizer.decode_pieces(tokens)
+
+    def ids_to_text(self, ids, lang):
+        if isinstance(ids, np.ndarray):
+            ids = ids.tolist()
+
+        # ids are expected in the monolingual tokenizer's own range for the given lang
+        tokenizer = self.tokenizers_dict[lang]
+        tokens = []
+        for id in ids:
+            tokens.extend(tokenizer.ids_to_tokens([id]))
+        text = ''.join(tokens).replace('▁', ' ')
+
+        return text
+
+    def token_to_id(self, token, lang_id):
+        tokenizer = self.tokenizers_dict[lang_id]
+        return tokenizer.token_to_id(token) + self.token_id_offset[lang_id]
+
+    def ids_to_tokens(self, ids):
+        tokens = []
+
+        for id in ids:
+            offset_id = self.offset_token_ids_by_token_id[id]
+            tokenizer = self.tokenizers_by_token_id[id]
+            token = tokenizer.ids_to_tokens([offset_id])[0]
+            tokens.append(token)
+
+        return tokens
+
+    def ids_to_text_and_langs(self, ids):
+        text_and_langs = []
+
+        for id in ids:
+            offset_id = self.offset_token_ids_by_token_id[id]
+            tokenizer = self.tokenizers_by_token_id[id]
+            token = tokenizer.ids_to_tokens([offset_id])[0]
+            text = token.replace('▁', ' ')
+            text = text.strip()  # strip for display purposes
+            lang = self.langs_by_token_id[id]
+            text_and_langs.append({'char': text, 'lang': lang})
+
+        return text_and_langs
+
+    def ids_to_words_and_langs(self, ids):
+        words_and_langs = []
+
+        word_ids = []  # tokens belonging to the current word
+        for id in ids:
+            offset_id = self.offset_token_ids_by_token_id[id]
+            tokenizer = self.tokenizers_by_token_id[id]
+            token = tokenizer.ids_to_tokens([offset_id])[0]
+            if token.startswith('▁'):
+                if len(word_ids) > 0:  # if this isn't the first word
+                    # determine the word's language first, then map the combined ids back to
+                    # language-local ids before decoding with that language's tokenizer
+                    lang = self.ids_to_lang(word_ids)
+                    word = self.ids_to_text([self.offset_token_ids_by_token_id[i] for i in word_ids], lang)
+                    word = word.strip()  # strip for display purposes
+                    wl = {'word': word, 'lang': lang}
+                    words_and_langs.append(wl)
+                    word_ids = []
+            word_ids.append(id)
+
+        if len(word_ids) > 0:  # the last tokens
+            lang = self.ids_to_lang(word_ids)
+            word = self.ids_to_text([self.offset_token_ids_by_token_id[i] for i in word_ids], lang)
+            word = word.strip()  # strip for display purposes
+            wl = {'word': word, 'lang': lang}
+            words_and_langs.append(wl)
+
+        return words_and_langs
+
+    def ids_to_lang(self, ids):
+        lang_cnts = {}
+
+        for id in ids:
+            lang = self.langs_by_token_id[id]
+            lang_cnt = lang_cnts.get(lang)
+            if lang_cnt is not None:
+                lang_cnts[lang] = lang_cnt + 1
+            else:
+                lang_cnts[lang] = 1
+
+        max_lang = ''
+        max_lang_cnt = -1
+        for lang, lang_cnt in lang_cnts.items():
+            if lang_cnt > max_lang_cnt:
+                max_lang = lang
+                max_lang_cnt = lang_cnt
+
+        return max_lang
+
+    def tokens_to_ids(self, tokens: Union[str, List[str]], langs: Union[str, List[str]]) -> Union[int, List[int]]:
+        if isinstance(tokens, str):
+            tokens = [tokens]
+        if isinstance(langs, str):
+            langs = [langs]
+
+        ids = []
+        for i, token in enumerate(tokens):
+            lang_id = langs[i]
+            ids.append(self.token_to_id(token, lang_id))
+        return ids
+
+    @property
+    def vocab(self):
+        return self.vocabulary
+
+    @property
+    def langs(self):
+        return list(self.tokenizers_dict.keys())
diff --git a/scripts/speech_recognition/code_switching/code_switching_manifest_creation.py b/scripts/speech_recognition/code_switching/code_switching_manifest_creation.py
index c783f803a..1f282230c 100644
--- a/scripts/speech_recognition/code_switching/code_switching_manifest_creation.py
+++ b/scripts/speech_recognition/code_switching/code_switching_manifest_creation.py
@@ -36,8 +36,8 @@
 parser.add_argument(
     "--id_language2", default=None, type=str, help='Identifier for language 2, eg: en, es, hi', required=True
 )
-parser.add_argument("--max_sample_duration_sec", default=19, type=int, help='Maximum duration of sample (sec)')
-parser.add_argument("--min_sample_duration_sec", default=16, type=int, help='Minimum duration of sample (sec)')
+parser.add_argument("--max_sample_duration_sec", default=30, type=int, help='Maximum duration of sample (sec)')
+parser.add_argument("--min_sample_duration_sec", default=20, type=int, help='Minimum duration of sample (sec)')
 parser.add_argument("--dataset_size_required_hrs", default=1, type=int, help='Duration of dataset required (hrs)')
 args = parser.parse_args()
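Reviewer note: a minimal usage sketch of the new MultilingualTokenizer, not part of the patch. The model paths and languages below are illustrative assumptions; as in the existing aggregate setup, each monolingual tokenizer is only assumed to expose .vocab and the usual text/id methods.

    # hypothetical paths and languages, for illustration only
    from nemo.collections.common.tokenizers import MultilingualTokenizer
    from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

    tokenizers = {
        'en': SentencePieceTokenizer(model_path='tokenizers/en/tokenizer.model'),
        'es': SentencePieceTokenizer(model_path='tokenizers/es/tokenizer.model'),
    }
    mtok = MultilingualTokenizer(tokenizers)

    # text_to_ids / ids_to_text keep ids in each language's own range; the lang id does the routing
    ids = mtok.text_to_ids('hello world', 'en')
    print(mtok.ids_to_text(ids, 'en'))

    # token_to_id, by contrast, returns ids shifted by token_id_offset[lang],
    # i.e. ids in the combined [0, vocab_size) range
    print(mtok.token_to_id('▁hello', 'en'))

This mirrors how language_ids are threaded through the collate function, greedy decoding, and the sampled joint above: each sample carries its language id, and the tokenizer and joint use it to select the per-language vocabulary rather than relying on a single merged id space.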