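"""CTC decoding utilities: LM-aware beam search (via ctcdecode), greedy
best-path decoding, and batch CER/WER metrics."""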
import editdistance
import numpy as np
import torch
from ctcdecode import CTCBeamDecoder

class Decoder:
    def __init__(self, labels, lm_path=None, alpha=1, beta=1.5, cutoff_top_n=40,
                 cutoff_prob=0.99, beam_width=200, num_processes=24, blank_id=0):
        self.vocab_list = ['_'] + labels  # NOTE: '_' is the CTC blank symbol
        # NOTE: the whitespace symbol (labels[0]) is replaced with an '@' symbol
        # for explicit modelling in char-based LMs, so the decoder vocabulary is
        # blank, '@', then the remaining labels.
        self._decoder = CTCBeamDecoder(['_', '@'] + labels[1:], lm_path, alpha,
                                       beta, cutoff_top_n, cutoff_prob,
                                       beam_width, num_processes, blank_id)
    def convert_to_string(self, tokens, seq_len=None):
        """Apply the CTC collapse rule to a 1-D tensor of token ids."""
        if seq_len is None:  # `is None`, so an explicit length of 0 is respected
            seq_len = tokens.size(0)
        out = []
        for i in range(seq_len):
            # Standard CTC collapse: skip blanks (id 0) and repeated tokens.
            if tokens[i] != 0 and (i == 0 or tokens[i] != tokens[i - 1]):
                out.append(tokens[i])
        return ''.join(self.vocab_list[t] for t in out)
    def decode_beam(self, logits, seq_lens):
        """Beam-search decoding; expects logits of shape (T, B, V)."""
        decoded = []
        tlogits = logits.transpose(0, 1)  # (T, B, V) -> (B, T, V), as ctcdecode expects
        beam_result, beam_scores, timesteps, out_seq_len = self._decoder.decode(tlogits.softmax(-1), seq_lens)
        for i in range(tlogits.size(0)):
            # Take the top beam (index 0), cut to its reported length.
            output_str = ''.join(self.vocab_list[x] for x in beam_result[i][0][:out_seq_len[i][0]])
            decoded.append(output_str)
        return decoded
    def decode_greedy(self, logits, seq_lens):
        """Greedy (best-path) decoding: per-frame argmax, then CTC collapse."""
        decoded = []
        tlogits = logits.transpose(0, 1)  # (T, B, V) -> (B, T, V)
        _, tokens = torch.max(tlogits, 2)
        for i in range(tlogits.size(0)):
            output_str = self.convert_to_string(tokens[i], seq_lens[i])
            decoded.append(output_str)
        return decoded
    def get_mean(self, decoded, gt, individual_length, func):
        """Mean of func(decoded[i], gt[i]) over the batch, normalised by the
        mean reference length, so the result is a rate rather than a raw count."""
        total_norm = 0.0
        length = len(decoded)
        for i in range(length):
            val = float(func(decoded[i], gt[i]))
            total_norm += val / individual_length
        return total_norm / length
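    # Worked example (illustrative, not from the original file): per-item
    # distances [1, 3] with a mean reference length of 4 give
    # (1/4 + 3/4) / 2 = 0.5.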
    def wer(self, r, h):
        """Levenshtein distance between reference r and hypothesis h (lists of words)."""
        # int32 rather than uint8: uint8 silently overflows once a distance
        # exceeds 255, corrupting the result for longer sequences.
        d = np.zeros((len(r) + 1, len(h) + 1), dtype=np.int32)
        # initialisation: distance from/to the empty sequence
        d[0, :] = np.arange(len(h) + 1)
        d[:, 0] = np.arange(len(r) + 1)
        # computation
        for i in range(1, len(r) + 1):
            for j in range(1, len(h) + 1):
                if r[i - 1] == h[j - 1]:
                    d[i][j] = d[i - 1][j - 1]
                else:
                    substitution = d[i - 1][j - 1] + 1
                    insertion = d[i][j - 1] + 1
                    deletion = d[i - 1][j] + 1
                    d[i][j] = min(substitution, insertion, deletion)
        return d[len(r)][len(h)]
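    # Worked example (illustrative, not from the original file):
    # wer(['the', 'cat', 'sat'], ['the', 'bat', 'sat']) -> 1 (one substitution);
    # wer_sentence('the cat sat', 'the cat') -> 1 (one deletion).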
    def wer_sentence(self, r, h):
        return self.wer(r.split(), h.split())
    def cer_batch(self, decoded, gt):
        assert len(decoded) == len(gt), 'batch size mismatch: {}!={}'.format(len(decoded), len(gt))
        mean_indiv_len = np.mean([len(s) for s in gt])
        return self.get_mean(decoded, gt, mean_indiv_len, editdistance.eval)
    def wer_batch(self, decoded, gt):
        assert len(decoded) == len(gt), 'batch size mismatch: {}!={}'.format(len(decoded), len(gt))
        mean_indiv_len = np.mean([len(s.split()) for s in gt])
        return self.get_mean(decoded, gt, mean_indiv_len, self.wer_sentence)
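
# Minimal usage sketch (not part of the original module). The label set,
# shapes, and reference strings below are assumptions for illustration;
# greedy decoding needs no language model, so lm_path stays None.
if __name__ == '__main__':
    labels = [' ', 'a', 'b', 'c']                       # assumed vocabulary; labels[0] is whitespace
    decoder = Decoder(labels)                           # '_' (blank) is prepended internally
    T, B, V = 50, 2, len(labels) + 1                    # time steps, batch size, vocab incl. blank
    logits = torch.randn(T, B, V)                       # random scores stand in for model output
    seq_lens = torch.full((B,), T, dtype=torch.int32)
    hypotheses = decoder.decode_greedy(logits, seq_lens)
    print(hypotheses)
    # Score against (assumed) references of matching batch size:
    print(decoder.cer_batch(hypotheses, ['ab c', 'ca b']))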