Skip to content

Commit

Permalink
Fix corner-case bug in CPU ctc
Browse files Browse the repository at this point in the history
  • Loading branch information
maetshju committed Aug 9, 2020
1 parent fa6ea80 commit 3ee7b36
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions src/losses/ctc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ function F(A, blank)
for curr in A[2:end]
if curr != prev && curr != blank
push!(z, curr)
`` end
end
prev = curr
end
return z
Expand Down Expand Up @@ -89,10 +89,8 @@ function ctc_(ŷ, y)
for t=1:T
for u=1:U′
if t == u == 1
# α[t,u] = ŷ[t, blank]
α[t,u] = ŷ[blank, t]
elseif t == 1 && u == 2
# α[t,u] = ŷ[t, z′[2]]
α[t,u] = ŷ[z′[2], t]
elseif t == 1 && u > 2
α[t,u] = -Inf
Expand All @@ -117,9 +115,10 @@ function ctc_(ŷ, y)
# Fill bottom-right corner so bounding errors can be avoided
# by starting `u` at `U′-1`
β[T,U′] = 0.0
β[T,U′-1] = 0.0

for t=T:-1:1
for u=(U′-1):-1:1
for u=U′:-1:1
if t == T && u >= U′ - 1
β[t,u] = 0.0
elseif t == T && u < U′ - 1
Expand All @@ -135,9 +134,6 @@ function ctc_(ŷ, y)
β[t, u] = logsum(v)
end
end
if t < T-1
β[t, U′] = β[t+1, U′] + ŷ[blank, t]
end
end

# Loss at each time t is taken as the sum of the product of the α and β coefficients for
Expand All @@ -163,7 +159,6 @@ function ctc_(ŷ, y)
end

losses = [x for x in losses]

return losses, grads
end

Expand Down

0 comments on commit 3ee7b36

Please sign in to comment.