
Merge pull request #9 from ryback123/nemo-v2
Added support for calculating confidence scores with multilingual models
kaushal-py authored Sep 29, 2024
2 parents c771ee8 + 89d12f3 commit af0138a
Showing 3 changed files with 22 additions and 20 deletions.
23 changes: 15 additions & 8 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -502,7 +502,7 @@ def rnnt_decoder_predictions_tensor(
             if self.preserve_frame_confidence and (
                 self.preserve_word_confidence or self.preserve_token_confidence
             ):
-                hypotheses = self.compute_confidence(hypotheses)
+                hypotheses = self.compute_confidence(hypotheses, lang_ids)
             return hypotheses, None

         best_hyp_text = [h.text for h in hypotheses]
@@ -561,7 +561,7 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis], lang_ids: List[str]):

         return hypotheses_list

-    def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothesis]:
+    def compute_confidence(self, hypotheses_list: List[Hypothesis], lang_ids: List[str] = None) -> List[Hypothesis]:
         """
         Computes high-level (per-token and/or per-word) confidence scores for a list of hypotheses.
         Assumes that `frame_confidence` is present in the hypotheses.
@@ -595,8 +595,11 @@ def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothesis]:
                     offset += 1
                 hyp.token_confidence = token_confidence
         if self.preserve_word_confidence:
-            for hyp in hypotheses_list:
-                hyp.word_confidence = self._aggregate_token_confidence(hyp)
+            for idx, hyp in enumerate(hypotheses_list):
+                if lang_ids:
+                    hyp.word_confidence = self._aggregate_token_confidence(hyp, lang_ids[idx])
+                else:
+                    hyp.word_confidence = self._aggregate_token_confidence(hyp)
         return hypotheses_list

     @abstractmethod
@@ -1401,7 +1404,7 @@ def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec, blank
         if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer):
             self.decoding.set_decoding_type('subword')

-    def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]:
+    def _aggregate_token_confidence(self, hypothesis: Hypothesis, lang_id: str = None) -> List[float]:
         """
         Implemented by subclass in order to reduce token confidence to a word-level confidence.

@@ -1414,7 +1417,7 @@ def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]:
             A list of word-level confidence scores.
         """
         return self._aggregate_token_confidence_subwords_sentencepiece(
-            hypothesis.words, hypothesis.token_confidence, hypothesis.y_sequence
+            hypothesis.words, hypothesis.token_confidence, hypothesis.y_sequence, lang_id
         )

     def decode_tokens_to_str(self, tokens: List[int], lang: str = None) -> str:
@@ -1431,9 +1434,10 @@ def decode_tokens_to_str(self, tokens: List[int], lang: str = None) -> str:
             hypothesis = self.tokenizer.ids_to_text(tokens, lang)
         else:
             hypothesis = self.tokenizer.ids_to_text(tokens)
+
         return hypothesis

-    def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:
+    def decode_ids_to_tokens(self, tokens: List[int], lang: str = None) -> List[str]:
         """
         Implemented by subclass in order to decode a token id list into a token list.
         A token list is the string representation of each token id.
@@ -1444,7 +1448,10 @@ def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:
         Returns:
             A list of decoded tokens.
         """
-        token_list = self.tokenizer.ids_to_tokens(tokens)
+        if lang is not None:
+            token_list = self.tokenizer.ids_to_tokens(tokens, lang)
+        else:
+            token_list = self.tokenizer.ids_to_tokens(tokens)
         return token_list

     def decode_tokens_to_lang(self, tokens: List[int]) -> str:
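
Note on the rnnt_decoding.py changes: they thread an optional per-utterance language id from the decoding entry point down to word-confidence aggregation, falling back to the old single-language path when no ids are given. A minimal, self-contained sketch of that dispatch (plain dicts and a toy aggregate() stand in for NeMo's Hypothesis and _aggregate_token_confidence; this is illustrative, not the actual API):

from typing import Dict, List, Optional

def aggregate(hyp: Dict, lang_id: Optional[str] = None) -> List[float]:
    # Toy stand-in: one "word" whose confidence is the mean token confidence.
    scores = hyp["token_confidence"]
    return [sum(scores) / len(scores)]

def compute_confidence_sketch(
    hypotheses: List[Dict], lang_ids: Optional[List[str]] = None
) -> List[Dict]:
    for idx, hyp in enumerate(hypotheses):
        if lang_ids:
            # New path: each hypothesis is aggregated under its own language id.
            hyp["word_confidence"] = aggregate(hyp, lang_ids[idx])
        else:
            # Old path, kept for monolingual models.
            hyp["word_confidence"] = aggregate(hyp)
    return hypotheses

print(compute_confidence_sketch([{"token_confidence": [0.9, 0.7]}], ["hi"]))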
6 changes: 3 additions & 3 deletions nemo/collections/asr/parts/utils/asr_confidence_utils.py
@@ -424,7 +424,7 @@ def _aggregate_token_confidence_chars(self, words: List[str], token_confidence:
         return word_confidence

     def _aggregate_token_confidence_subwords_sentencepiece(
-        self, words: List[str], token_confidence: List[float], token_ids: List[int]
+        self, words: List[str], token_confidence: List[float], token_ids: List[int], lang_id: str = None
     ) -> List[float]:
         """Implementation of token confidence aggregation for subword-based models.
@@ -445,8 +445,8 @@ def _aggregate_token_confidence_subwords_sentencepiece(
         prev_unk = False
         prev_underline = False
         for i, token_id in enumerate(token_ids):
-            token = self.decode_ids_to_tokens([int(token_id)])[0]
-            token_text = self.decode_tokens_to_str([int(token_id)])
+            token = self.decode_ids_to_tokens([int(token_id)], lang_id)[0]
+            token_text = self.decode_tokens_to_str([int(token_id)], lang_id)
             # treat `<unk>` as a separate word regardless of the next token
             # to match the result of `tokenizer.ids_to_text`
             if (token != token_text or prev_unk) and i > j:
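
Note on the asr_confidence_utils.py changes: passing lang_id through decode_ids_to_tokens and decode_tokens_to_str lets the aggregator detect word boundaries with the correct per-language tokenizer. A simplified sketch of how subword confidences collapse into word confidences (the '▁' word-start convention is standard sentencepiece behavior; the tokens, scores, and min-aggregation choice here are illustrative, and the real method also handles <unk> and leading-underline edge cases):

tokens = ["▁नम", "स्ते", "▁दुनिया"]   # sentencepiece pieces; '▁' marks a word start
scores = [0.9, 0.8, 0.6]

word_confidence, group = [], []
for tok, score in zip(tokens, scores):
    if tok.startswith("▁") and group:      # a new word begins: flush the old group
        word_confidence.append(min(group))
        group = []
    group.append(score)
if group:
    word_confidence.append(min(group))

print(word_confidence)  # [0.8, 0.6]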
13 changes: 4 additions & 9 deletions nemo/collections/common/tokenizers/multilingual_tokenizer.py
@@ -117,10 +117,10 @@ def ids_to_text(self, ids, lang):
             ids = ids.tolist()

         tokens = []
+        tokenizer = self.tokenizers_dict[lang]
         for id in ids:
             # offset_id = self.offset_token_ids_by_token_id[id]
             # tokenizer = self.tokenizers_by_token_id[id]
-            tokenizer = self.tokenizers_dict[lang]
             # tokens.extend(tokenizer.ids_to_tokens([offset_id]))
             tokens.extend(tokenizer.ids_to_tokens([id]))
         text = ''.join(tokens).replace('▁', ' ')
@@ -131,14 +131,9 @@ def token_to_id(self, token, lang_id):
         tokenizer = self.tokenizers_dict[lang_id]
         return tokenizer.token_to_id(token) + self.token_id_offset[lang_id]

-    def ids_to_tokens(self, ids):
-        tokens = []
-
-        for id in ids:
-            offset_id = self.offset_token_ids_by_token_id[id]
-            tokenizer = self.tokenizers_by_token_id[id]
-            token = tokenizer.ids_to_tokens([offset_id])[0]
-            tokens.append(token)
+    def ids_to_tokens(self, ids, lang_id):
+        tokenizer = self.tokenizers_dict[lang_id]
+        tokens = [tokenizer.ids_to_tokens([id])[0] for id in ids]

         return tokens

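
Note on the multilingual_tokenizer.py changes: ids_to_tokens now takes an explicit lang_id and resolves ids with that language's tokenizer, replacing the old per-id offset-table lookup (the commented-out offset_token_ids_by_token_id path). A toy model of the new contract (the two-language vocab below is invented; real tokenizers_dict entries are per-language tokenizers, not dicts):

class ToyMultilingualTokenizer:
    def __init__(self):
        self.tokenizers_dict = {
            "en": {0: "▁hello", 1: "▁world"},
            "hi": {0: "▁नमस्ते", 1: "▁दुनिया"},
        }

    def ids_to_tokens(self, ids, lang_id):
        # Same ids decode to different tokens depending on lang_id.
        vocab = self.tokenizers_dict[lang_id]
        return [vocab[i] for i in ids]

tok = ToyMultilingualTokenizer()
print(tok.ids_to_tokens([0, 1], "en"))  # ['▁hello', '▁world']
print(tok.ids_to_tokens([0, 1], "hi"))  # ['▁नमस्ते', '▁दुनिया']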
