Change alphas to multi-dim indexing
Change probs to multi-dim indexing in alpha kernel

Change probs to multi-dim indexing in beta kernel

Change beta coefficients to multi-dim indexing

Change output to multi-dim indexing

Update accum to multi-dim indexing

Update gpu kernel to multi-dim indexing
maetshju committed Jul 27, 2020
1 parent b9072d2 commit fa6ea80
Showing 1 changed file with 40 additions and 68 deletions.
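The change is mechanical throughout: every buffer that was previously addressed as a flat vector with hand-computed column offsets becomes a U′ × T matrix addressed as [row, column]. Julia arrays are column-major, so for a K × T matrix the flat index (t-1)*K + k and the pair [k, t] reach the same element. A minimal sketch of that equivalence (the names here are illustrative, not taken from the diff):

# Illustrative only: flat column-offset indexing vs. multi-dim indexing.
K, T = 5, 3                                # alphabet size × time steps
probs = reshape(collect(1.0:K*T), K, T)    # column-major, like the kernel buffers
k, t = 2, 3
@assert probs[(t - 1) * K + k] == probs[k, t]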
108 changes: 40 additions & 68 deletions src/losses/ctc-gpu.jl
@@ -7,7 +7,8 @@
using Flux
using Statistics
using CUDA
-using DelimitedFiles

+const MAX_THREADS = 256

function log_plus_f(p1, p2)

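The body of log_plus_f is collapsed in this hunk; it computes log(exp(p1) + exp(p2)) without leaving log-space. A standard numerically stable sketch of that operation (an assumed form for orientation, not necessarily the exact body here):

# Stable log(exp(p1) + exp(p2)); -Inf32 acts as the log-space zero.
function log_add(p1::Float32, p2::Float32)
    isinf(p1) && return p2
    isinf(p2) && return p1
    hi, lo = p1 > p2 ? (p1, p2) : (p2, p1)
    return hi + log1p(exp(lo - hi))
end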
@@ -51,24 +52,21 @@ function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutB
# Fill in first column (time step)
i = tid
while i <= last - start
-alpha[start + i] = probs[labels[start + i]]
+alpha[start+i, 1] = probs[labels[start+i], 1]
i += blockDim().x
end

sync_threads()

# Fill in coefficients for each time step
for t=2:T
-startCurCol = (t-1) * S
-startPrevCol = (t-2) * S
-startProbCol = (t-1) * div(length(probs), T)

# Corner-case checking
if tid == 1 && !(1 < S - 2*(T-t) - 1)
if start == 0
-alpha[startCurCol + 1] = probs[startProbCol + blankLabel] + alpha[startPrevCol + 1]
+alpha[1, t] = probs[blankLabel, t] + alpha[1, t-1]
elseif start == 1
-alpha[startCurCol + 1] = alpha[startPrevCol + 1]
+alpha[1, t] = alpha[1, t-1]
end
end

@@ -79,16 +77,16 @@ function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutB
idx = tid+1
while idx <= S

-prevSum = log_plus_f(alpha[startPrevCol + idx], alpha[startPrevCol + idx-1])
+prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1])

if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]
-prevSum = log_plus_f(prevSum, alpha[startPrevCol + idx-2])
+prevSum = log_plus_f(prevSum, alpha[idx-2, t-1])
end

if idx < S - 2*(T-t) - 1
-alpha[idx + startCurCol] = -Inf32
+alpha[idx, t] = -Inf32
else
-alpha[startCurCol + idx] = prevSum + probs[startProbCol + labels[idx]]
+alpha[idx, t] = prevSum + probs[labels[idx], t]
end

idx += blockDim().x
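Two things worth noting in this hunk. First, the while idx <= S … idx += blockDim().x shape is a block-stride loop: thread tid handles indices tid, tid + blockDim().x, tid + 2·blockDim().x, and so on, so correctness does not depend on launching one thread per index (this is what later lets the launch cap threads at MAX_THREADS). Second, the recursion itself is the standard CTC forward pass over the blank-augmented label sequence. A plain-Julia CPU reference for that recursion, minus the kernel's S - 2*(T-t) - 1 pruning window, assuming the log_add sketch above (a hypothetical helper, not the kernel's exact code):

# CPU reference: α over the blank-augmented labels z′ (length S);
# probs is an alphabet × T matrix of log-probabilities.
function alpha_cpu(probs, z′, blank)
    S, T = length(z′), size(probs, 2)
    α = fill(-Inf32, S, T)
    α[1, 1] = probs[blank, 1]
    S > 1 && (α[2, 1] = probs[z′[2], 1])
    for t in 2:T, s in 1:S
        a = α[s, t-1]
        s > 1 && (a = log_add(a, α[s-1, t-1]))
        if s > 2 && z′[s] != blank && z′[s] != z′[s-2]
            a = log_add(a, α[s-2, t-1])   # skip transition between distinct labels
        end
        α[s, t] = a + probs[z′[s], t]
    end
    return α
end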
@@ -122,52 +120,40 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength,

sync_threads()


-startCurCol = (T-1)*S
-startProbCol = (T-1) * div(length(probs), T)

i = tid

# Calculate coefficients for last column (time step)
# then determine alpha and beta product
while i <= last - start + 1

-beta[startCurCol + i + start] = 0
-output[startCurCol + i + start] = beta[startCurCol + i + start] + alphas[startCurCol + i + start]
+beta[i+start, T] = 0
+output[i+start, T] = beta[i+start, T] + alphas[i+start, T]
i += blockDim().x
end

sync_threads()

# Fill in `accum` for last column (time step)
-if tid == 1
-startAccCol = startProbCol
-startOutputCol = startCurCol

+if tid == 1
for i=1:S
labelIdx = labels[i]
-accum[startAccCol + labelIdx] = log_plus_f(accum[startAccCol + labelIdx], output[startOutputCol + i])
+accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T])
end
end

sync_threads()

# Fill in `grad` for last column (time step)
idx = tid
-# while idx <= CUDA.div_fast(Float32(length(grad)), Float32(T))
while idx <= size(grad, 1)
-#
-startProbCol = (T - 1) * div(length(probs), T)
-startOutputCol = (T - 1) * S

s = -Inf32

for i=1:S
-s = log_plus_f(s, output[startOutputCol + i])
+s = log_plus_f(s, output[i, T])
end

# ∂L/∂a (where a is activation before logsoftmax)
-# grad[startProbCol + idx] = CUDA.exp(probs[startProbCol + idx]) - CUDA.exp(accum[startProbCol + idx] - s)
-grad[idx, T] = CUDA.exp(probs[startProbCol + idx]) - CUDA.exp(accum[startProbCol + idx] - s)
+grad[idx, T] = CUDA.exp(probs[idx, T]) - CUDA.exp(accum[idx, T] - s)
idx += blockDim().x
end

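For orientation: the two lines above implement the standard CTC gradient with respect to the pre-softmax activations, written in log-space. With p the log-softmax output, accum the per-class log-sum of α·β products, and s the per-column normalizer,

s_t = log Σ_i exp(output[i, t])
∂L/∂a[k, t] = exp(p[k, t]) - exp(accum[k, t] - s_t)

i.e. softmax(a)[k, t] minus the posterior probability that class k is emitted at time t.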
@@ -176,28 +162,23 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength,
# Fill in the rest of the coefficients
t = T-1
while t >= 1

-startCurCol = (t-1)*S
-startNextCol = t*S
-startProbCol = t * div(length(probs), T)

if t < T

idx = tid
while idx <= S-1

-nextSum = log_plus_f(beta[startNextCol + idx] + probs[startProbCol + labels[idx]],
-beta[startNextCol + idx+1] + probs[startProbCol + labels[idx+1]])
+nextSum = log_plus_f(beta[idx, t+1] + probs[labels[idx], t+1],
+beta[idx+1, t+1] + probs[labels[idx+1], t+1])

if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
nextSum = log_plus_f(nextSum,
-beta[startNextCol + idx + 2] + probs[startProbCol + labels[idx+2]])
+beta[idx + 2, t+1] + probs[labels[idx+2], t+1])
end

if idx > 2*t
-beta[idx + startCurCol] = -Inf32
+beta[idx, t] = -Inf32
else
-beta[idx + startCurCol] = nextSum
+beta[idx, t] = nextSum

end

@@ -207,14 +188,14 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength,
sync_threads()

if tid == 1 && last == S
-beta[startCurCol + S] = beta[startNextCol + S] + probs[startProbCol + blankLabel]
+beta[S, t] = beta[S, t+1] + probs[blankLabel, t+1]
end

sync_threads()

idx = tid
while idx <= S
-output[startCurCol + idx] = alphas[idx+startCurCol] + beta[startCurCol + idx]
+output[idx, t] = alphas[idx, t] + beta[idx, t]
idx += blockDim().x
end

@@ -226,14 +207,10 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength,

# Calculate accumulated alpha-beta products for each label class for
# each time step; used in calculating gradients
-if tid == 1

-startAccCol = (t-1) * div(length(accum), T)
-startOutputCol = (t-1) * S

+if tid == 1
for i=1:S
labelIdx = labels[i]
-accum[startAccCol + labelIdx] = log_plus_f(accum[startAccCol + labelIdx], output[startOutputCol + i])
+accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t])
end
end

@@ -243,17 +220,15 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength,

# Calculate gradients
while idx <= size(grad, 1)
-#
-startProbCol = (t - 1) * div(length(probs), T)
-startOutputCol = (t - 1) * S

s = -Inf32

for i=1:S
-s = log_plus_f(s, output[startOutputCol + i])
+s = log_plus_f(s, output[i, t])
end

# ∂L/∂a (where a is activation before logsoftmax)
-grad[idx, t] = CUDA.exp(probs[startProbCol + idx]) - CUDA.exp(accum[startProbCol + idx] - s)
+grad[idx, t] = CUDA.exp(probs[idx, t]) - CUDA.exp(accum[idx, t] - s)
idx += blockDim().x
end

@@ -266,20 +241,15 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength,
return nothing
end

# methods for `ctc_` helper function
ctc(ŷ::CuArray, y::Array) = ctc_(ŷ, y)[1] |> mean

ctc(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y))[1] |> mean

ctc(ŷ::CuArray, y::CuArray) = ctc_(ŷ, collect(y))[1] |> mean

# methods for `ctc_` helper function
-ctc_(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), y)
+ctc_(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y))

function ctc_(ŷ::CuArray, y)

ŷ = logsoftmax(ŷ)
-if any(isinf.(ŷ)) error("Inf in yhat") end
-if any(isnan.(ŷ)) error("NaN in yhat") end

blank = size(ŷ, 1)
labels = [Base.argmax(y[:,i]) for i in 1:size(y, 2)]
@@ -289,23 +259,25 @@ function ctc_(ŷ::CuArray, y)
push!(z′, label)
push!(z′, blank)
end

T = size(ŷ, 2)
U′ = 2*length(z) + 1
-alphas = CUDA.fill(log(zero(ŷ[1])), T * U′)
-betas = copy(alphas)
-output = copy(alphas)

+alphas = CUDA.fill(log(zero(ŷ[1])), U′, T)
+betas = CUDA.fill(log(zero(ŷ[1])), U′, T)
+output = CUDA.fill(log(zero(ŷ[1])), U′, T)

nRepeats = countRepeats(labels)
+nThreads = min(U′, MAX_THREADS)

-# 1 block with `U′` threads
-@cuda blocks=1 threads=U′ computeAlphaKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank)
+@cuda blocks=1 threads=nThreads computeAlphaKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank)

grads = CUDA.fill(log(zero(ŷ[1])), size(ŷ))
-accum = CUDA.fill(log(zero(ŷ[1])), length(ŷ))
+accum = CUDA.fill(log(zero(ŷ[1])), size(ŷ))

-@cuda blocks=1 threads=U′ computeBetasAndGradKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank)
+@cuda blocks=1 threads=nThreads computeBetasAndGradKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank)
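On the launch change: threads=U′ would exceed the device's per-block thread limit for long enough labels, while min(U′, MAX_THREADS) stays safe precisely because every per-label loop in both kernels strides by blockDim().x. A quick illustration that a block-stride loop with n threads still covers all S indices (hypothetical helper, not from the diff):

# Each thread tid visits tid, tid + n, tid + 2n, …; together they tile 1:S.
covered(tid, n, S) = tid:n:S
@assert sort(reduce(vcat, [collect(covered(tid, 3, 10)) for tid in 1:3])) == collect(1:10)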

-ls = reshape(collect(output), U′, T)
+ls = collect(output)
ls = vec(-1 .* [logsum(ls[:,i]) for i in 1:size(ls, 2)])

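logsum here is assumed to be the helper shared with the CPU CTC code: a log-sum-exp over a vector, which collapses each column of α·β products into a single log value whose negation becomes that column's loss term. A standard stable sketch of such a helper (the real definition is not part of this diff):

# Stable log Σ exp(xs); returns -Inf for an all--Inf column.
logsum_sketch(xs) = (m = maximum(xs); isinf(m) ? m : m + log(sum(exp.(xs .- m))))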
ŷ = alphas = betas = output = accum = nothing
