Skip to content

Commit

Permalink
Fix GPU beta kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
maetshju committed Aug 30, 2020
1 parent 3ee7b36 commit ff6ddf5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
14 changes: 10 additions & 4 deletions src/losses/ctc-gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,16 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength,
if t < T

idx = tid
while idx <= S-1
# while idx <= S-1
while idx <= S

nextSum = log_plus_f(beta[idx, t+1] + probs[labels[idx], t+1],
beta[idx+1, t+1] + probs[labels[idx+1], t+1])
nextSum = beta[idx, t+1] + probs[labels[idx], t+1]

if idx < S

nextSum = log_plus_f(nextSum,
beta[idx+1, t+1] + probs[labels[idx+1], t+1])
end

if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
nextSum = log_plus_f(nextSum,
Expand Down Expand Up @@ -279,7 +285,7 @@ function ctc_(ŷ::CuArray, y)

ls = collect(output)
ls = vec(-1 .* [logsum(ls[:,i]) for i in 1:size(ls, 2)])

= alphas = betas = output = accum = nothing
return ls, grads
end
13 changes: 11 additions & 2 deletions test/ctc-gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ end
g1 = gradient(ctc, x_cu, y_cu)[1]
g1 = collect(g1)

g2 = ctc_ngradient(x, y)[1]
g2 = ctc_ngradient(x_cu, y_cu)[1] |> collect

@test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5))

# test that GPU loss matches CPU implementation

l1 = Flux.Losses.ctc_(x_cu, y_cu)[1]
l2 = Flux.Losses.ctc_(x, y)[1]
l2 = Flux.Losses.ctc_(Float32.(x), y)[1]

@test all(isapprox.(l1, l2, rtol=1e-5, atol=1e-5))

Expand All @@ -64,5 +64,14 @@ end
ghat = gradient(ctc, x_cu, y_cu)[1] |> collect

@test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5))

x_cu = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] |> CuArray
y_cu = [1 1 0 0; 0 0 1 1; 0 0 0 0] |> CuArray
@test ctc(x_cu, y_cu) 8.02519869363453

g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07]

ghat = gradient(ctc, x_cu, y_cu)[1] |> collect
@test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5))

end

0 comments on commit ff6ddf5

Please sign in to comment.