Added support for calculating confidence scores with multilingual models #9

Merged 3 commits on Sep 29, 2024
23 changes: 15 additions & 8 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -502,7 +502,7 @@ def rnnt_decoder_predictions_tensor(
if self.preserve_frame_confidence and (
self.preserve_word_confidence or self.preserve_token_confidence
):
- hypotheses = self.compute_confidence(hypotheses)
+ hypotheses = self.compute_confidence(hypotheses, lang_ids)
return hypotheses, None

best_hyp_text = [h.text for h in hypotheses]
@@ -561,7 +561,7 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis], lang_ids: List[st

return hypotheses_list

- def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothesis]:
+ def compute_confidence(self, hypotheses_list: List[Hypothesis], lang_ids: List[str] = None) -> List[Hypothesis]:
"""
Computes high-level (per-token and/or per-word) confidence scores for a list of hypotheses.
Assumes that `frame_confidence` is present in the hypotheses.
Expand Down Expand Up @@ -595,8 +595,11 @@ def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothes
offset += 1
hyp.token_confidence = token_confidence
if self.preserve_word_confidence:
- for hyp in hypotheses_list:
- hyp.word_confidence = self._aggregate_token_confidence(hyp)
+ for idx, hyp in enumerate(hypotheses_list):
+ if lang_ids:
+ hyp.word_confidence = self._aggregate_token_confidence(hyp, lang_ids[idx])
+ else:
+ hyp.word_confidence = self._aggregate_token_confidence(hyp)
return hypotheses_list

@abstractmethod
@@ -1401,7 +1404,7 @@ def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec, blank
if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer):
self.decoding.set_decoding_type('subword')

- def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]:
+ def _aggregate_token_confidence(self, hypothesis: Hypothesis, lang_id: str = None) -> List[float]:
"""
Implemented by subclass in order to reduce token confidence to a word-level confidence.

@@ -1414,7 +1417,7 @@ def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]:
A list of word-level confidence scores.
"""
return self._aggregate_token_confidence_subwords_sentencepiece(
- hypothesis.words, hypothesis.token_confidence, hypothesis.y_sequence
+ hypothesis.words, hypothesis.token_confidence, hypothesis.y_sequence, lang_id
)

def decode_tokens_to_str(self, tokens: List[int], lang: str = None) -> str:
@@ -1431,9 +1434,10 @@ def decode_tokens_to_str(self, tokens: List[int], lang: str = None) -> str:
hypothesis = self.tokenizer.ids_to_text(tokens, lang)
else:
hypothesis = self.tokenizer.ids_to_text(tokens)

return hypothesis

- def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:
+ def decode_ids_to_tokens(self, tokens: List[int], lang: str = None) -> List[str]:
"""
Implemented by subclass in order to decode a token id list into a token list.
A token list is the string representation of each token id.
@@ -1444,7 +1448,10 @@ def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:
Returns:
A list of decoded tokens.
"""
- token_list = self.tokenizer.ids_to_tokens(tokens)
+ if lang is not None:
+ token_list = self.tokenizer.ids_to_tokens(tokens, lang)
+ else:
+ token_list = self.tokenizer.ids_to_tokens(tokens)
return token_list

def decode_tokens_to_lang(self, tokens: List[int]) -> str:
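As a quick illustration of the dispatch this file now implements, here is a minimal, self-contained sketch of how per-utterance language IDs are threaded into word-confidence aggregation. The `Hyp` dataclass and the `aggregate` function below are toy stand-ins, not the NeMo classes; the real `_aggregate_token_confidence` groups subword confidences into words via the language-specific tokenizer.

```python
# Toy sketch of the per-hypothesis language dispatch (stand-in types, not NeMo).
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Hyp:
    token_confidence: List[float]
    word_confidence: List[float] = field(default_factory=list)


def aggregate(hyp: Hyp, lang_id: Optional[str] = None) -> List[float]:
    # Stand-in for _aggregate_token_confidence: here we just take the minimum,
    # while the real code maps subword confidences onto words per language.
    return [min(hyp.token_confidence)] if hyp.token_confidence else []


def compute_word_confidence(hyps: List[Hyp], lang_ids: Optional[List[str]] = None) -> List[Hyp]:
    # Mirrors the new branch: with lang_ids, each hypothesis uses its own language ID.
    for idx, hyp in enumerate(hyps):
        if lang_ids:
            hyp.word_confidence = aggregate(hyp, lang_ids[idx])
        else:
            hyp.word_confidence = aggregate(hyp)
    return hyps


print(compute_word_confidence([Hyp([0.9, 0.7]), Hyp([0.8])], lang_ids=["en", "es"]))
```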
6 changes: 3 additions & 3 deletions nemo/collections/asr/parts/utils/asr_confidence_utils.py
@@ -424,7 +424,7 @@ def _aggregate_token_confidence_chars(self, words: List[str], token_confidence:
return word_confidence

def _aggregate_token_confidence_subwords_sentencepiece(
- self, words: List[str], token_confidence: List[float], token_ids: List[int]
+ self, words: List[str], token_confidence: List[float], token_ids: List[int], lang_id: str = None
) -> List[float]:
"""Implementation of token confidence aggregation for subword-based models.

@@ -445,8 +445,8 @@ def _aggregate_token_confidence_subwords_sentencepiece(
prev_unk = False
prev_underline = False
for i, token_id in enumerate(token_ids):
- token = self.decode_ids_to_tokens([int(token_id)])[0]
- token_text = self.decode_tokens_to_str([int(token_id)])
+ token = self.decode_ids_to_tokens([int(token_id)], lang_id)[0]
+ token_text = self.decode_tokens_to_str([int(token_id)], lang_id)
# treat `<unk>` as a separate word regardless of the next token
# to match the result of `tokenizer.ids_to_text`
if (token != token_text or prev_unk) and i > j:
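For context on why the language ID has to reach the decoding helpers: word boundaries are detected by comparing a token's raw form (e.g. '▁hello') with its text form ('hello'), and with multiple languages that comparison must use the tokenizer of the hypothesis' language. Below is a simplified, self-contained illustration of subword-to-word aggregation; it assumes SentencePiece-style '▁' word markers and uses min-aggregation, so it is illustrative only and not the NeMo implementation.

```python
# Toy illustration of grouping subword confidences into word confidences.
from typing import List


def aggregate_subword_confidence(tokens: List[str], confidence: List[float]) -> List[float]:
    # Tokens starting with '▁' open a new word; a word's subword confidences
    # are reduced with min() here (one of several possible policies).
    word_confidence: List[float] = []
    current: List[float] = []
    for token, conf in zip(tokens, confidence):
        if token.startswith('▁') and current:
            word_confidence.append(min(current))
            current = []
        current.append(conf)
    if current:
        word_confidence.append(min(current))
    return word_confidence


# '▁hel' + 'lo' form one word, '▁world' another -> two word-level scores.
print(aggregate_subword_confidence(['▁hel', 'lo', '▁world'], [0.9, 0.8, 0.95]))
```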
13 changes: 4 additions & 9 deletions nemo/collections/common/tokenizers/multilingual_tokenizer.py
@@ -117,10 +117,10 @@ def ids_to_text(self, ids, lang):
ids = ids.tolist()

tokens = []
+ tokenizer = self.tokenizers_dict[lang]
for id in ids:
# offset_id = self.offset_token_ids_by_token_id[id]
# tokenizer = self.tokenizers_by_token_id[id]
- tokenizer = self.tokenizers_dict[lang]
# tokens.extend(tokenizer.ids_to_tokens([offset_id]))
tokens.extend(tokenizer.ids_to_tokens([id]))
text = ''.join(tokens).replace('▁', ' ')
@@ -131,14 +131,9 @@ def token_to_id(self, token, lang_id):
tokenizer = self.tokenizers_dict[lang_id]
return tokenizer.token_to_id(token) + self.token_id_offset[lang_id]

- def ids_to_tokens(self, ids):
- tokens = []
-
- for id in ids:
- offset_id = self.offset_token_ids_by_token_id[id]
- tokenizer = self.tokenizers_by_token_id[id]
- token = tokenizer.ids_to_tokens([offset_id])[0]
- tokens.append(token)
+ def ids_to_tokens(self, ids, lang_id):
+ tokenizer = self.tokenizers_dict[lang_id]
+ tokens = [tokenizer.ids_to_tokens([id])[0] for id in ids]

return tokens

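A small sketch of the revised lookup pattern in `ids_to_tokens`: instead of resolving each id through the global offset tables, the caller names the language and all ids are decoded by that language's tokenizer. `ToyTokenizer` below is a placeholder for the per-language SentencePiece tokenizers held in `tokenizers_dict`, not a real NeMo class.

```python
# Toy sketch of per-language token lookup (placeholder tokenizer, not NeMo).
from typing import Dict, List


class ToyTokenizer:
    """Placeholder for a per-language SentencePiece tokenizer."""

    def __init__(self, vocab: List[str]):
        self.vocab = vocab

    def ids_to_tokens(self, ids: List[int]) -> List[str]:
        return [self.vocab[i] for i in ids]


class ToyMultilingualTokenizer:
    def __init__(self, tokenizers_dict: Dict[str, ToyTokenizer]):
        self.tokenizers_dict = tokenizers_dict

    def ids_to_tokens(self, ids: List[int], lang_id: str) -> List[str]:
        # The caller names the language; every id is decoded with that
        # language's tokenizer, i.e. ids are treated as language-local.
        tokenizer = self.tokenizers_dict[lang_id]
        return [tokenizer.ids_to_tokens([i])[0] for i in ids]


mt = ToyMultilingualTokenizer({'en': ToyTokenizer(['▁the', '▁cat']),
                               'es': ToyTokenizer(['▁el', '▁gato'])})
print(mt.ids_to_tokens([0, 1], lang_id='es'))  # ['▁el', '▁gato']
```

One design consequence visible in the diff: the old path mapped each id through `offset_token_ids_by_token_id` and `tokenizers_by_token_id`, while the new path takes the ids as-is for the named language, so this code path appears to expect un-offset (language-local) token ids.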