WIP: custom operators for parallel_reduce #120

Draft: wants to merge 1 commit into base: main
59 changes: 31 additions & 28 deletions ext/JACCCUDA/JACCCUDA.jl
@@ -57,18 +57,21 @@ function JACC.parallel_for(
CUDA.@sync @cuda threads = (Lthreads, Mthreads, Nthreads) blocks = (Lblocks, Mblocks, Nblocks) shmem = shmem_size _parallel_for_cuda_LMN(f, x...)
end

function JACC.parallel_reduce(
N::I, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(N::Integer, op, f::Function, x...; init)
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
ret = CUDA.zeros(Float64, blocks)
rret = CUDA.zeros(Float64, 1)
ret = fill!(CUDA.CuArray{typeof(init)}(undef, 1), init)
rret = CUDA.CuArray([init])
CUDA.@sync @cuda threads=threads blocks=blocks shmem=512 * sizeof(Float64) _parallel_reduce_cuda(
N, ret, f, x...)
N, op, ret, f, x...)
CUDA.@sync @cuda threads=threads blocks=1 shmem=512 * sizeof(Float64) reduce_kernel_cuda(
blocks, ret, rret)
return rret
blocks, op, ret, rret)
return Base.Array(rret)[]
end

function JACC.parallel_reduce(N::Integer, f::Function, x...)
return JACC.parallel_reduce(N, +, f, x...; init = zero(Float64))
end

function JACC.parallel_reduce(
@@ -113,7 +116,7 @@ function _parallel_for_cuda_LMN(f, x...)
return nothing
end

function _parallel_reduce_cuda(N, ret, f, x...)
function _parallel_reduce_cuda(N, op, ret, f, x...)
shared_mem = @cuDynamicSharedMem(Float64, 512)
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
ti = threadIdx().x
@@ -126,52 +129,52 @@ function _parallel_reduce_cuda(N, ret, f, x...)
end
sync_threads()
if (ti <= 256)
shared_mem[ti] += shared_mem[ti + 256]
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 256])
end
sync_threads()
if (ti <= 128)
shared_mem[ti] += shared_mem[ti + 128]
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 128])
end
sync_threads()
if (ti <= 64)
shared_mem[ti] += shared_mem[ti + 64]
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 64])
end
sync_threads()
if (ti <= 32)
shared_mem[ti] += shared_mem[ti + 32]
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 32])
end
sync_threads()
if (ti <= 16)
shared_mem[ti] += shared_mem[ti + 16]
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 16])
end
sync_threads()
if (ti <= 8)
shared_mem[ti] += shared_mem[ti + 8]
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 8])
end
sync_threads()
if (ti <= 4)
shared_mem[ti] += shared_mem[ti + 4]
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 4])
end
sync_threads()
if (ti <= 2)
shared_mem[ti] += shared_mem[ti + 2]
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 2])
end
sync_threads()
if (ti == 1)
shared_mem[ti] += shared_mem[ti + 1]
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 1])
ret[blockIdx().x] = shared_mem[ti]
end
return nothing
end

function reduce_kernel_cuda(N, red, ret)
function reduce_kernel_cuda(N, op, red, ret)
shared_mem = @cuDynamicSharedMem(Float64, 512)
i = threadIdx().x
ii = i
tmp::Float64 = 0.0
if N > 512
while ii <= N
tmp += @inbounds red[ii]
tmp = op(tmp, @inbounds red[ii])
ii += 512
end
elseif (i <= N)
@@ -180,39 +183,39 @@ function reduce_kernel_cuda(N, red, ret)
shared_mem[threadIdx().x] = tmp
sync_threads()
if (i <= 256)
shared_mem[i] += shared_mem[i + 256]
shared_mem[i] = op(shared_mem[i], shared_mem[i + 256])
end
sync_threads()
if (i <= 128)
shared_mem[i] += shared_mem[i + 128]
shared_mem[i] = op(shared_mem[i], shared_mem[i + 128])
end
sync_threads()
if (i <= 64)
shared_mem[i] += shared_mem[i + 64]
shared_mem[i] = op(shared_mem[i], shared_mem[i + 64])
end
sync_threads()
if (i <= 32)
shared_mem[i] += shared_mem[i + 32]
shared_mem[i] = op(shared_mem[i], shared_mem[i + 32])
end
sync_threads()
if (i <= 16)
shared_mem[i] += shared_mem[i + 16]
shared_mem[i] = op(shared_mem[i], shared_mem[i + 16])
end
sync_threads()
if (i <= 8)
shared_mem[i] += shared_mem[i + 8]
shared_mem[i] = op(shared_mem[i], shared_mem[i + 8])
end
sync_threads()
if (i <= 4)
shared_mem[i] += shared_mem[i + 4]
shared_mem[i] = op(shared_mem[i], shared_mem[i + 4])
end
sync_threads()
if (i <= 2)
shared_mem[i] += shared_mem[i + 2]
shared_mem[i] = op(shared_mem[i], shared_mem[i + 2])
end
sync_threads()
if (i == 1)
shared_mem[i] += shared_mem[i + 1]
shared_mem[i] = op(shared_mem[i], shared_mem[i + 1])
ret[1] = shared_mem[1]
end
return nothing
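With this hunk, the CUDA backend's JACC.parallel_reduce accepts an arbitrary binary operator plus an init keyword and returns a host scalar instead of a one-element device array; the old two-argument form falls back to op = + with init = zero(Float64). A rough usage sketch under the new signature (the vector and sizes here are illustrative, not taken from this PR):

using JACC

N = 1000
x = JACC.ones(Float64, N)

# Sum of squares: op = + with neutral element zero(Float64); expected 1000.0 for a vector of ones.
s = JACC.parallel_reduce(N, +, (i, v) -> v[i] * v[i], x; init = zero(Float64))

# Maximum element: op = max with init = -Inf, the pattern the new "reduce" testset exercises below.
m = JACC.parallel_reduce(N, max, (i, v) -> v[i], x; init = -Inf)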
14 changes: 9 additions & 5 deletions src/JACC.jl
@@ -50,18 +50,22 @@ function parallel_for(
end
end

function parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
tmp = zeros(Threads.nthreads())
ret = zeros(1)
function parallel_reduce(N::Integer, op, f::Function, x...; init)
ret = init
tmp = fill(init, Threads.nthreads())
@maybe_threaded for i in 1:N
tmp[Threads.threadid()] = tmp[Threads.threadid()] .+ f(i, x...)
tmp[Threads.threadid()] = op.(tmp[Threads.threadid()], f(i, x...))
end
for i in 1:Threads.nthreads()
ret = ret .+ tmp[i]
ret = op.(ret, tmp[i])
end
return ret
end

function parallel_reduce(N::Integer, f::Function, x...)
return parallel_reduce(N, +, f, x...; init = zeros(1))
end

function parallel_reduce(
(M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
tmp = zeros(Threads.nthreads())
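The threaded backend mirrors the CUDA change: per-thread partial results start from init and are folded with op, while the old two-argument form remains as a wrapper around op = + with init = zeros(1), so it still returns a one-element array. A minimal sketch of the two entry points, using a ones vector purely for illustration:

using JACC

N = 100
y = JACC.ones(Float64, N)

# Old form: dispatches to op = + with init = zeros(1), so the result is a 1-element array.
total_vec = JACC.parallel_reduce(N, (i, v) -> v[i], y)

# New form: an explicit operator and a scalar init yield a scalar result.
total = JACC.parallel_reduce(N, +, (i, v) -> v[i], y; init = 0.0)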
132 changes: 70 additions & 62 deletions test/tests_cuda.jl
@@ -97,69 +97,77 @@ end
@test zeros(N)≈Array(x) rtol=1e-5
end

# @testset "CG" begin

# function matvecmul(i, a1, a2, a3, x, y, SIZE)
# if i == 1
# y[i] = a2[i] * x[i] + a1[i] * x[i+1]
# elseif i == SIZE
# y[i] = a3[i] * x[i-1] + a2[i] * x[i]
# elseif i > 1 && i < SIZE
# y[i] = a3[i] * x[i-1] + a1[i] * +x[i] + a1[i] * +x[i+1]
# end
# end

# function dot(i, x, y)
# @inbounds return x[i] * y[i]
# end

# function axpy(i, alpha, x, y)
# @inbounds x[i] += alpha[1, 1] * y[i]
# end

# SIZE = 10
# a0 = JACC.ones(Float64, SIZE)
# a1 = JACC.ones(Float64, SIZE)
# a2 = JACC.ones(Float64, SIZE)
# r = JACC.ones(Float64, SIZE)
# p = JACC.ones(Float64, SIZE)
# s = JACC.zeros(Float64, SIZE)
# x = JACC.zeros(Float64, SIZE)
# r_old = JACC.zeros(Float64, SIZE)
# r_aux = JACC.zeros(Float64, SIZE)
# a1 = a1 * 4
# r = r * 0.5
# p = p * 0.5
# global cond = one(Float64)

# while cond[1, 1] >= 1e-14

# r_old = copy(r)

# JACC.parallel_for(SIZE, matvecmul, a0, a1, a2, p, s, SIZE)

# alpha0 = JACC.parallel_reduce(SIZE, dot, r, r)
# alpha1 = JACC.parallel_reduce(SIZE, dot, p, s)

# alpha = alpha0 / alpha1
# negative_alpha = alpha * (-1.0)

# JACC.parallel_for(SIZE, axpy, negative_alpha, r, s)
# JACC.parallel_for(SIZE, axpy, alpha, x, p)

# beta0 = JACC.parallel_reduce(SIZE, dot, r, r)
# beta1 = JACC.parallel_reduce(SIZE, dot, r_old, r_old)
# beta = beta0 / beta1

# r_aux = copy(r)
@testset "CG" begin

function matvecmul(i, a1, a2, a3, x, y, SIZE)
if i == 1
y[i] = a2[i] * x[i] + a1[i] * x[i+1]
elseif i == SIZE
y[i] = a3[i] * x[i-1] + a2[i] * x[i]
elseif i > 1 && i < SIZE
y[i] = a3[i] * x[i-1] + a1[i] * +x[i] + a1[i] * +x[i+1]
end
end

function dot(i, x, y)
@inbounds return x[i] * y[i]
end

function axpy(i, alpha, x, y)
@inbounds x[i] += alpha[1, 1] * y[i]
end

SIZE = 10
a0 = JACC.ones(Float64, SIZE)
a1 = JACC.ones(Float64, SIZE)
a2 = JACC.ones(Float64, SIZE)
r = JACC.ones(Float64, SIZE)
p = JACC.ones(Float64, SIZE)
s = JACC.zeros(Float64, SIZE)
x = JACC.zeros(Float64, SIZE)
r_old = JACC.zeros(Float64, SIZE)
r_aux = JACC.zeros(Float64, SIZE)
a1 = a1 * 4
r = r * 0.5
p = p * 0.5
cond = one(Float64)

while cond[1, 1] >= 1e-14

r_old = copy(r)

JACC.parallel_for(SIZE, matvecmul, a0, a1, a2, p, s, SIZE)

alpha0 = JACC.parallel_reduce(SIZE, dot, r, r)
alpha1 = JACC.parallel_reduce(SIZE, dot, p, s)

alpha = alpha0 / alpha1
negative_alpha = alpha * (-1.0)

JACC.parallel_for(SIZE, axpy, negative_alpha, r, s)
JACC.parallel_for(SIZE, axpy, alpha, x, p)

beta0 = JACC.parallel_reduce(SIZE, dot, r, r)
beta1 = JACC.parallel_reduce(SIZE, dot, r_old, r_old)
beta = beta0 / beta1

r_aux = copy(r)

JACC.parallel_for(SIZE, axpy, beta, r_aux, p)
ccond = JACC.parallel_reduce(SIZE, dot, r, r)
cond = ccond
p = copy(r_aux)
end
@test cond[1, 1] <= 1e-14
end

# JACC.parallel_for(SIZE, axpy, beta, r_aux, p)
# ccond = JACC.parallel_reduce(SIZE, dot, r, r)
# global cond = ccond
# p = copy(r_aux)
# end
# @test cond[1, 1] <= 1e-14
# end
@testset "reduce" begin
SIZE = 100
ah = randn(SIZE)
ad = JACC.Array(ah)
mxd = JACC.parallel_reduce(SIZE, max, (i,a)->a[i], ad; init = -Inf)
@test mxd == maximum(ah)
end
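The new testset covers max with a custom init; a companion check for another operator could reuse the dot pattern from the CG testset above. A possible sketch, not part of this commit:

@testset "reduce dot" begin
    SIZE = 100
    xh = randn(SIZE)
    yh = randn(SIZE)
    xd = JACC.Array(xh)
    yd = JACC.Array(yh)
    # Explicit + operator with a scalar neutral element gives a scalar dot product.
    dd = JACC.parallel_reduce(SIZE, +, (i, a, b) -> a[i] * b[i], xd, yd; init = zero(Float64))
    @test dd ≈ sum(xh .* yh) rtol=1e-8
end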

# @testset "LBM" begin
