forked from edward-zhu/umaru
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathctc_lua.lua
158 lines (119 loc) · 3.48 KB
/
ctc_lua.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
require 'utils/logs'
-- Global table that collects the CTC helper functions defined in this file.
ctc = {}
--[[
getOnehotMatrix
Build a one-hot encoding of a vector of class indices.
- target: 1-D tensor of length L holding class indices. Entries that are
  not greater than 0 (the filler zeros written by getFilledTarget) are
  mapped onto the last column.
- class_num: number of columns C of the result.
return an L * C matrix with exactly one 1 per row.
]]
function ctc.getOnehotMatrix(target, class_num)
    local len = (#target)[1]
    -- `onehot` and `c` used to be assigned without `local`, which leaked
    -- them into the global environment on every call.
    local onehot = torch.zeros(len, class_num)
    for i = 1, len do
        local c = target[i]
        if c > 0 then
            onehot[i][c] = 1
        else
            -- non-positive entries share the last column
            onehot[i][class_num] = 1
        end
    end
    return onehot
end
--[[
getFilledTarget
Interleave the characters of the ground truth with zeros.
- target: ground-truth string. Each character must be convertible to a
  number with tonumber (e.g. a digit). `#target` is the byte length, so
  multi-byte characters are not supported.
return a 1-D tensor of length 2L + 1 (L = #target): even positions 2k hold
the numeric value of the k-th character, odd positions stay 0.
]]
function ctc.getFilledTarget(target)
    local len = #target
    local filled = torch.zeros(len * 2 + 1)
    for pos = 1, len do
        local ch = string.sub(target, pos, pos)
        -- Convert explicitly instead of relying on the tensor assignment to
        -- coerce a string; a bad character now fails with a readable message.
        local class = tonumber(ch)
        if class == nil then
            error(string.format(
                "ctc.getFilledTarget: character %q at position %d cannot be converted to a class number",
                ch, pos), 2)
        end
        filled[2 * pos] = class
    end
    return filled
end
--[[
toMatrix
Stack a table of per-step output tensors into a single matrix.
- outputTable: non-empty Lua table of tensors, one per time step, all with
  the same number of elements.
- class_num: (optional) number of columns of the result. Defaults to the
  element count of outputTable[1]; the previous hard-coded value was 11.
return a (#outputTable) * class_num matrix.
]]
function ctc.toMatrix(outputTable, class_num)
    local steps = #outputTable
    assert(steps > 0, "ctc.toMatrix: outputTable must contain at least one tensor")
    class_num = class_num or outputTable[1]:nElement()

    local net = nn.Sequential()
    -- concatenate all steps along the first dimension, then fold the result
    -- into one row per step
    net:add(nn.JoinTable(1))
    net:add(nn.Reshape(steps, class_num))
    return net:forward(outputTable)
end
--[[
getForwardVariable
Compute the CTC forward variable for every (t, u) by dynamic programming.
- outputTable: matrix with one row per time step; only its row count T is used.
- alignedTable: T * L' matrix; entry (t, u) weights position u of the filled
  target at time t (getCTCCost builds it as outputTable * target:t()).
- target: the filled target with L' rows (one-hot matrix, or a 1-D tensor of
  class indices). Odd positions are the blank fillers, even positions the labels.
return a T * L' matrix of forward variables.
]]--
function ctc.getForwardVariable(outputTable, alignedTable, target)
    local T = (#outputTable)[1]
    local L = (#target)[1]
    local fvs = torch.zeros(T, L)

    -- Value comparison of two target entries. `~=` on two row tensors compares
    -- object identity and is therefore always true, which wrongly allowed the
    -- skip transition between repeated labels.
    local function same_label(a, b)
        if type(a) == "number" or type(b) == "number" then
            return a == b
        end
        return a:eq(b):min() == 1
    end

    -- skip_ok[u] is true when the transition u - 2 -> u is allowed: u is a
    -- label position (even) and its label differs from the previous label.
    -- It does not depend on t, so compute it once.
    local skip_ok = {}
    for u = 4, L, 2 do
        skip_ok[u] = not same_label(target[u - 2], target[u])
    end

    -- initialize: a path may start on the leading blank or on the first label
    fvs[1][1] = alignedTable[1][1]
    if L > 1 then
        fvs[1][2] = alignedTable[1][2]
    end

    -- at time i at most the first 2 * i positions are reachable
    local upper_bound = 2
    for i = 2, T do
        upper_bound = math.min(upper_bound + 2, L)
        for u = 1, upper_bound do
            local acc = 0
            -- label positions may additionally be reached by skipping the
            -- blank between two different labels
            if skip_ok[u] then acc = acc + fvs[i - 1][u - 2] end
            if u > 1 then acc = acc + fvs[i - 1][u - 1] end
            acc = acc + fvs[i - 1][u]
            fvs[i][u] = acc * alignedTable[i][u]
        end
    end

    return fvs
end
--[[
getBackwardVariable
Compute the CTC backward variable for every (t, u) by dynamic programming.
- outputTable: matrix with one row per time step; only its row count T is used.
- alignedTable: T * L' matrix; entry (t, u) weights position u of the filled
  target at time t.
- target: the filled target with L' rows (one-hot matrix, or a 1-D tensor of
  class indices). Odd positions are the blank fillers, even positions the labels.
return a T * L' matrix of backward variables; row T is 1 at the last two
positions and 0 elsewhere.
]]--
function ctc.getBackwardVariable(outputTable, alignedTable, target)
    local T = (#outputTable)[1]
    local L = (#target)[1]
    local bvs = torch.zeros(T, L)

    -- Value comparison of two target entries; `~=` on two row tensors compares
    -- object identity and would always report "different".
    local function same_label(a, b)
        if type(a) == "number" or type(b) == "number" then
            return a == b
        end
        return a:eq(b):min() == 1
    end

    -- skip_ok[u] is true when the transition u -> u + 2 is allowed: u is a
    -- label position (even) and the next label differs from it.
    local skip_ok = {}
    for u = 2, L - 2, 2 do
        skip_ok[u] = not same_label(target[u + 2], target[u])
    end

    -- initialize: a path may end on the trailing blank or on the last label
    bvs[T][L] = 1
    if L > 1 then
        bvs[T][L - 1] = 1
    end

    for i = T - 1, 1, -1 do
        for u = L, 1, -1 do
            local acc = alignedTable[i + 1][u] * bvs[i + 1][u]
            if u < L then
                acc = acc + alignedTable[i + 1][u + 1] * bvs[i + 1][u + 1]
            end
            -- The skip depends on the label position u, not on the time
            -- step i (the original branched on `i % 2`).
            if skip_ok[u] then
                acc = acc + alignedTable[i + 1][u + 2] * bvs[i + 1][u + 2]
            end
            bvs[i][u] = acc
        end
    end

    return bvs
end
--[[
getCTCCost
Compute the CTC cost of a target string given the per-step outputs.
- outputTable: Lua table of 1-D output tensors, one per time step; the length
  of the first one is taken as the number of classes.
- target: ground-truth string (see getFilledTarget).
return -log p(target | outputs), where p is the sum of the forward variables
of the last two positions at the final time step. The forward variables are
not rescaled, so the probability can underflow for long sequences.
]]
function ctc.getCTCCost(outputTable, target)
    local class_num = (#(outputTable[1]))[1]

    -- filled target -> one-hot matrix with L' rows and class_num columns
    local filled = ctc.getFilledTarget(target)
    local onehot = ctc.getOnehotMatrix(filled, class_num)

    -- T x class_num matrix of outputs
    local outputs = ctc.toMatrix(outputTable)

    -- aligned table: (T x class_num) * (class_num x L') = T x L'
    local alignedTable = outputs * onehot:t()

    -- `fvs` and `bvs` used to be globals; keep them local to the call
    local fvs = ctc.getForwardVariable(outputs, alignedTable, onehot)
    local bvs = ctc.getBackwardVariable(outputs, alignedTable, onehot)

    -- debug dump kept from the original so console output does not change
    print(bvs)

    local T = (#fvs)[1]
    local U = (#fvs)[2]
    local pzx = fvs[T][U]
    if U > 1 then
        pzx = pzx + fvs[T][U - 1]
    end
    return -math.log(pzx)
end