diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index 4524c0f54f..7162ec51fa 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -52,7 +52,7 @@ function F(A, blank) for curr in A[2:end] if curr != prev && curr != blank push!(z, curr) -`` end + end prev = curr end return z @@ -89,10 +89,8 @@ function ctc_(ŷ, y) for t=1:T for u=1:U′ if t == u == 1 -# α[t,u] = ŷ[t, blank] α[t,u] = ŷ[blank, t] elseif t == 1 && u == 2 -# α[t,u] = ŷ[t, z′[2]] α[t,u] = ŷ[z′[2], t] elseif t == 1 && u > 2 α[t,u] = -Inf @@ -117,9 +115,10 @@ function ctc_(ŷ, y) # Fill bottom-right corner so bounding errors can be avoided # by starting `u` at `U′-1` β[T,U′] = 0.0 + β[T,U′-1] = 0.0 for t=T:-1:1 - for u=(U′-1):-1:1 + for u=U′:-1:1 if t == T && u >= U′ - 1 β[t,u] = 0.0 elseif t == T && u < U′ - 1 @@ -135,9 +134,6 @@ function ctc_(ŷ, y) β[t, u] = logsum(v) end end - if t < T-1 - β[t, U′] = β[t+1, U′] + ŷ[blank, t] - end end # Loss at each time t is taken as the sum of the product of the α and β coefficients for @@ -163,7 +159,6 @@ function ctc_(ŷ, y) end losses = [x for x in losses] - return losses, grads end