Skip to content

Commit

Permalink
Add multi-softmax architecture for CTC, RNN-T and Hybrid models
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushal-py committed Jan 15, 2024
1 parent 9b1774e commit 3bddb03
Show file tree
Hide file tree
Showing 25 changed files with 1,983 additions and 192 deletions.
44 changes: 32 additions & 12 deletions nemo/collections/asr/data/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,27 @@ def _speech_collate_fn(batch, pad_id):
assumes the signals are 1d torch tensors (i.e. mono audio).
"""
packed_batch = list(zip(*batch))
if len(packed_batch) == 5:
if len(packed_batch) == 6: # has language ids
_, audio_lengths, _, tokens_lengths, sample_ids, language_ids = packed_batch
elif len(packed_batch) == 5: # has sample ids
language_ids = None
_, audio_lengths, _, tokens_lengths, sample_ids = packed_batch
elif len(packed_batch) == 4:
sample_ids = None
sample_ids, language_ids = None, None
_, audio_lengths, _, tokens_lengths = packed_batch
else:
raise ValueError("Expects 4 or 5 tensors in the batch!")
raise ValueError("Expects 4 or 5 or 6 tensors in the batch!")
max_audio_len = 0
has_audio = audio_lengths[0] is not None
if has_audio:
max_audio_len = max(audio_lengths).item()
max_tokens_len = max(tokens_lengths).item()

audio_signal, tokens = [], []
for b in batch:
if len(b) == 5:
if len(b) == 6:
sig, sig_len, tokens_i, tokens_i_len, _, _ = b
elif len(b) == 5:
sig, sig_len, tokens_i, tokens_i_len, _ = b
else:
sig, sig_len, tokens_i, tokens_i_len = b
Expand All @@ -97,12 +102,14 @@ def _speech_collate_fn(batch, pad_id):
audio_signal, audio_lengths = None, None
tokens = torch.stack(tokens)
tokens_lengths = torch.stack(tokens_lengths)
if sample_ids is None:
return audio_signal, audio_lengths, tokens, tokens_lengths
else:
if language_ids is not None:
sample_ids = torch.tensor(sample_ids, dtype=torch.int32)
return audio_signal, audio_lengths, tokens, tokens_lengths, sample_ids, list(language_ids)
elif sample_ids is not None:
sample_ids = torch.tensor(sample_ids, dtype=torch.int32)
return audio_signal, audio_lengths, tokens, tokens_lengths, sample_ids

else:
return audio_signal, audio_lengths, tokens, tokens_lengths

class ASRManifestProcessor:
"""
Expand Down Expand Up @@ -424,6 +431,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
'transcripts': NeuralType(('B', 'T'), LabelsType()),
'transcript_length': NeuralType(tuple('B'), LengthsType()),
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
'language_id': [NeuralType(('B'), StringType(), optional=True)],
}

def __init__(
Expand All @@ -441,6 +449,7 @@ def __init__(
eos_id: Optional[int] = None,
pad_id: int = 0,
return_sample_id: bool = False,
return_language_id: bool = False,
channel_selector: Optional[ChannelSelectorType] = None,
):
if type(manifest_filepath) == str:
Expand All @@ -462,6 +471,7 @@ def __init__(
self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
self.trim = trim
self.return_sample_id = return_sample_id
self.return_language_id = return_language_id
self.channel_selector = channel_selector

def get_manifest_sample(self, sample_id):
Expand All @@ -486,8 +496,10 @@ def __getitem__(self, index):

t, tl = self.manifest_processor.process_text_by_sample(sample=sample)

if self.return_sample_id:
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), index
if self.return_language_id:
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), index, sample.lang
elif self.return_sample_id:
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), index
else:
output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long()

Expand Down Expand Up @@ -530,6 +542,7 @@ class AudioToCharDataset(_AudioTextDataset):
bos_id: Id of beginning of sequence symbol to append if not None
eos_id: Id of end of sequence symbol to append if not None
return_sample_id (bool): whether to return the sample_id as a part of each sample
return_language_id (bool): whether to return the language_id as a part of each sample
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
"""

Expand All @@ -543,6 +556,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
'transcripts': NeuralType(('B', 'T'), LabelsType()),
'transcript_length': NeuralType(tuple('B'), LengthsType()),
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
'language_id': [NeuralType(('B'), StringType(), optional=True)],
}

def __init__(
Expand All @@ -564,6 +578,7 @@ def __init__(
pad_id: int = 0,
parser: Union[str, Callable] = 'en',
return_sample_id: bool = False,
return_language_id: bool = False,
channel_selector: Optional[ChannelSelectorType] = None,
):
self.labels = labels
Expand All @@ -586,6 +601,7 @@ def __init__(
eos_id=eos_id,
pad_id=pad_id,
return_sample_id=return_sample_id,
return_language_id=return_language_id,
channel_selector=channel_selector,
)

Expand Down Expand Up @@ -624,6 +640,7 @@ class AudioToBPEDataset(_AudioTextDataset):
use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS]
tokens to beginning and ending of speech respectively.
return_sample_id (bool): whether to return the sample_id as a part of each sample
return_language_id (bool): whether to return the language_id as a part of each sample
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
"""

Expand All @@ -637,6 +654,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
'transcripts': NeuralType(('B', 'T'), LabelsType()),
'transcript_length': NeuralType(tuple('B'), LengthsType()),
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
'language_id': [NeuralType(('B'), StringType(), optional=True)],
}

def __init__(
Expand All @@ -652,6 +670,7 @@ def __init__(
trim: bool = False,
use_start_end_token: bool = True,
return_sample_id: bool = False,
return_language_id: bool = False,
channel_selector: Optional[ChannelSelectorType] = None,
):
if use_start_end_token and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0:
Expand All @@ -671,7 +690,7 @@ def __init__(

class TokenizerWrapper:
def __init__(self, tokenizer):
if isinstance(tokenizer, tokenizers.aggregate_tokenizer.AggregateTokenizer):
if isinstance(tokenizer, tokenizers.aggregate_tokenizer.AggregateTokenizer) or isinstance(tokenizer, tokenizers.multilingual_tokenizer.MultilingualTokenizer):
self.is_aggregate = True
else:
self.is_aggregate = False
Expand Down Expand Up @@ -701,6 +720,7 @@ def __call__(self, *args):
pad_id=pad_id,
trim=trim,
return_sample_id=return_sample_id,
return_language_id=return_language_id,
channel_selector=channel_selector,
)

Expand Down
5 changes: 4 additions & 1 deletion nemo/collections/asr/data/audio_to_text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def get_char_dataset(config: dict, augmentor: Optional['AudioAugmentor'] = None)
trim=config.get('trim_silence', False),
parser=config.get('parser', 'en'),
return_sample_id=config.get('return_sample_id', False),
return_language_id=config.get('return_language_id', False),
channel_selector=config.get('channel_selector', None),
)
return dataset
Expand Down Expand Up @@ -231,6 +232,7 @@ def get_bpe_dataset(
trim=config.get('trim_silence', False),
use_start_end_token=config.get('use_start_end_token', True),
return_sample_id=config.get('return_sample_id', False),
return_language_id=config.get('return_language_id', False),
channel_selector=config.get('channel_selector', None),
)
return dataset
Expand Down Expand Up @@ -630,7 +632,8 @@ def get_audio_to_text_bpe_dataset_from_config(
logging.warning(f"`concat_sampling_probabilities` need to sum to 1. Config: {config}")
return None

shuffle = config['shuffle']
if config.get('shuffle', False):
shuffle = False
device = 'gpu' if torch.cuda.is_available() else 'cpu'
if config.get('use_dali', False):
device_id = local_rank if device == 'gpu' else None
Expand Down
11 changes: 10 additions & 1 deletion nemo/collections/asr/losses/ssl_losses/contrastive.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,17 @@ def forward(self, spectrograms, spec_masks, decoder_outputs, decoder_lengths=Non
targets.transpose(0, 1), targets_masked_only.size(0), # TxBxC # T'
)
else:
# only sample from masked steps in utterance
negatives, _ = self.sample_negatives(targets_masked_only, targets_masked_only.size(0)) # T'xBxC # T'
# if targets_masked_only.size(0) >= self.num_negatives:
# # only sample from masked steps in utterance
# negatives, _ = self.sample_negatives(targets_masked_only, targets_masked_only.size(0)) # T'xBxC # T'
# else: # for shorter samples (<8s)
# # print(f"sampling from non-masked ({self.num_negatives},{targets_masked_only.size(0)})")
# # sample from all steps in utterance
# negatives, _ = self.sample_negatives(
# targets.transpose(0, 1), targets_masked_only.size(0), # TxBxC # T'
# )

# NxT'xBxC

out_masked_only = out_masked_only.reshape(-1, out_masked_only.shape[-1])
Expand Down
15 changes: 10 additions & 5 deletions nemo/collections/asr/metrics/rnnt_wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def rnnt_decoder_predictions_tensor(
encoded_lengths: torch.Tensor,
return_hypotheses: bool = False,
partial_hypotheses: Optional[List[Hypothesis]] = None,
lang_ids: List[str] = None,
) -> Tuple[List[str], Optional[List[List[str]]], Optional[Union[Hypothesis, NBestHypotheses]]]:
"""
Decode an encoder output by autoregressive decoding of the Decoder+Joint networks.
Expand All @@ -408,9 +409,10 @@ def rnnt_decoder_predictions_tensor(
Look at rnnt_utils.NBestHypotheses for more information.
"""
# Compute hypotheses
# print("Decode strategy:", self.cfg.strategy)
with torch.inference_mode():
hypotheses_list = self.decoding(
encoder_output=encoder_output, encoded_lengths=encoded_lengths, partial_hypotheses=partial_hypotheses
encoder_output=encoder_output, encoded_lengths=encoded_lengths, partial_hypotheses=partial_hypotheses, language_ids=lang_ids,
) # type: [List[Hypothesis]]

# extract the hypotheses
Expand All @@ -424,7 +426,7 @@ def rnnt_decoder_predictions_tensor(

for nbest_hyp in prediction_list: # type: NBestHypotheses
n_hyps = nbest_hyp.n_best_hypotheses # Extract all hypotheses for this sample
decoded_hyps = self.decode_hypothesis(n_hyps) # type: List[str]
decoded_hyps = self.decode_hypothesis(n_hyps, lang_ids) # type: List[str]

# If computing timestamps
if self.compute_timestamps is True:
Expand All @@ -443,7 +445,7 @@ def rnnt_decoder_predictions_tensor(
return best_hyp_text, all_hyp_text

else:
hypotheses = self.decode_hypothesis(prediction_list) # type: List[str]
hypotheses = self.decode_hypothesis(prediction_list, lang_ids) # type: List[str]

# If computing timestamps
if self.compute_timestamps is True:
Expand All @@ -462,7 +464,7 @@ def rnnt_decoder_predictions_tensor(
best_hyp_text = [h.text for h in hypotheses]
return best_hyp_text, None

def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hypothesis, NBestHypotheses]]:
def decode_hypothesis(self, hypotheses_list: List[Hypothesis], lang_ids: List[str] = None) -> List[Union[Hypothesis, NBestHypotheses]]:
"""
Decode a list of hypotheses into a list of strings.
Expand Down Expand Up @@ -498,7 +500,10 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp
token_repetitions = [1] * len(alignments) # preserve number of repetitions per token
hypothesis = (prediction, alignments, token_repetitions)
else:
hypothesis = self.decode_tokens_to_str(prediction)
if lang_ids is not None:
hypothesis = self.decode_tokens_to_str(prediction, lang_ids[ind])
else:
hypothesis = self.decode_tokens_to_str(prediction)

# TODO: remove
# collapse leading spaces before . , ? for PC models
Expand Down
24 changes: 16 additions & 8 deletions nemo/collections/asr/metrics/rnnt_wer_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,9 @@ class RNNTBPEDecoding(AbstractRNNTDecoding):
tokenizer: The tokenizer which will be used for decoding.
"""

def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec):
blank_id = tokenizer.tokenizer.vocab_size
def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec, blank_id=None):
if blank_id is None:
blank_id = tokenizer.tokenizer.vocab_size
self.tokenizer = tokenizer

super(RNNTBPEDecoding, self).__init__(
Expand All @@ -222,7 +223,7 @@ def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]:
hypothesis.words, hypothesis.token_confidence, hypothesis.y_sequence
)

def decode_tokens_to_str(self, tokens: List[int]) -> str:
def decode_tokens_to_str(self, tokens: List[int], lang: str = None) -> str:
"""
Implemented by subclass in order to decoder a token list into a string.
Expand All @@ -232,7 +233,10 @@ def decode_tokens_to_str(self, tokens: List[int]) -> str:
Returns:
A decoded string.
"""
hypothesis = self.tokenizer.ids_to_text(tokens)
if lang is not None:
hypothesis = self.tokenizer.ids_to_text(tokens, lang)
else:
hypothesis = self.tokenizer.ids_to_text(tokens)
return hypothesis

def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:
Expand Down Expand Up @@ -275,7 +279,7 @@ def decode_ids_to_langs(self, tokens: List[int]) -> List[str]:
lang_list = self.tokenizer.ids_to_text_and_langs(tokens)
return lang_list

def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hypothesis, NBestHypotheses]]:
def decode_hypothesis(self, hypotheses_list: List[Hypothesis], lang_ids: List[str] = None) -> List[Union[Hypothesis, NBestHypotheses]]:
"""
Decode a list of hypotheses into a list of strings.
Overrides the super() method optionally adding lang information
Expand All @@ -286,7 +290,7 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp
Returns:
A list of strings.
"""
hypotheses = super().decode_hypothesis(hypotheses_list)
hypotheses = super().decode_hypothesis(hypotheses_list, lang_ids)
if self.compute_langs:
if isinstance(self.tokenizer, AggregateTokenizer):
for ind in range(len(hypotheses_list)):
Expand Down Expand Up @@ -371,6 +375,7 @@ def update(
encoded_lengths: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
lang_ids: List[str] = None,
) -> torch.Tensor:
words = 0
scores = 0
Expand All @@ -385,10 +390,13 @@ def update(
for ind in range(targets_cpu_tensor.shape[0]):
tgt_len = tgt_lenths_cpu_tensor[ind].item()
target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist()
reference = self.decoding.decode_tokens_to_str(target)
if lang_ids is not None:
reference = self.decoding.decode_tokens_to_str(target, lang_ids[ind])
else:
reference = self.decoding.decode_tokens_to_str(target)
references.append(reference)

hypotheses, _ = self.decoding.rnnt_decoder_predictions_tensor(encoder_output, encoded_lengths)
hypotheses, _ = self.decoding.rnnt_decoder_predictions_tensor(encoder_output, encoded_lengths, lang_ids=lang_ids)

if self.log_prediction:
logging.info(f"\n")
Expand Down
Loading

0 comments on commit 3bddb03

Please sign in to comment.