Skip to content

Commit

Permalink
Merge pull request #104 from williamfgc/reorder
Browse files Browse the repository at this point in the history
Reorder GPU grid indices
  • Loading branch information
williamfgc authored Jul 1, 2024
2 parents 828f77e + e4146dd commit 01068b7
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
6 changes: 3 additions & 3 deletions ext/JACCAMDGPU/JACCAMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end
function JACC.parallel_for(
(L, M, N)::Tuple{I, I, I}, f::F, x...) where {
I <: Integer, F <: Function}
numThreads = 16
numThreads = 32
Lthreads = min(L, numThreads)
Mthreads = min(M, numThreads)
Nthreads = 1
Expand Down Expand Up @@ -93,9 +93,9 @@ function _parallel_for_amdgpu_MN(f, x...)
end

function _parallel_for_amdgpu_LMN(f, x...)
k = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
i = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
j = (workgroupIdx().y - 1) * workgroupDim().y + workitemIdx().y
i = (workgroupIdx().z - 1) * workgroupDim().z + workitemIdx().z
k = (workgroupIdx().z - 1) * workgroupDim().z + workitemIdx().z
f(i, j, k, x...)
return nothing
end
Expand Down
7 changes: 3 additions & 4 deletions ext/JACCCUDA/JACCCUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ function JACC.parallel_for(
I <: Integer, F <: Function}
#To use JACC.shared, it is recommended to use a high number of threads per block to maximize the
# potential benefit from using shared memory.
#numThreads = 32
numThreads = 16
numThreads = 32
Lthreads = min(L, numThreads)
Mthreads = min(M, numThreads)
Nthreads = 1
Expand Down Expand Up @@ -106,9 +105,9 @@ function _parallel_for_cuda_MN(f, x...)
end

function _parallel_for_cuda_LMN(f, x...)
k = (blockIdx().x - 1) * blockDim().x + threadIdx().x
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
j = (blockIdx().y - 1) * blockDim().y + threadIdx().y
i = (blockIdx().z - 1) * blockDim().z + threadIdx().z
k = (blockIdx().z - 1) * blockDim().z + threadIdx().z
f(i, j, k, x...)
return nothing
end
Expand Down
4 changes: 2 additions & 2 deletions ext/JACCONEAPI/JACCONEAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ function _parallel_for_oneapi_MN(f, x...)
end

function _parallel_for_oneapi_LMN(f, x...)
k = get_global_id(0)
i = get_global_id(0)
j = get_global_id(1)
i = get_global_id(2)
k = get_global_id(2)
f(i, j, k, x...)
return nothing
end
Expand Down

0 comments on commit 01068b7

Please sign in to comment.