Create an internal API which deploys off of JACCPreferences.backend #86

Open
wants to merge 24 commits into base: main

Commits (24)
785d0af
Modify all CPU code
kmp5VT May 3, 2024
ca0d9de
More updates, GPU and array portion
kmp5VT May 3, 2024
50bd061
convert Array to array
kmp5VT May 3, 2024
694c2c8
Add comment
kmp5VT May 3, 2024
a8f9158
Allow JACC to precompile convert preference to symbol and use that to…
kmp5VT May 8, 2024
c1ffbbc
explicitly overwrite JACC.arraytype
kmp5VT May 8, 2024
a31b1d9
Add error about backend package not loaded.
kmp5VT May 8, 2024
17522a8
remove comment
kmp5VT May 9, 2024
70410a9
update AMDGPU tests
kmp5VT May 9, 2024
75a2040
update oneapi tests
kmp5VT May 9, 2024
5180159
Merge remote-tracking branch 'origin/main' into kmp5/refactor/internals
kmp5VT May 9, 2024
89b2301
Don't precompile
kmp5VT May 9, 2024
599aa7d
Move JACC_BACKEND_TYPE to JACCPreferences
kmp5VT May 9, 2024
fea569a
Merge branch 'main' into kmp5/refactor/internals
kmp5VT May 13, 2024
bb4fb35
Move JACCArrayType to its own file and bring into preferences.jl
kmp5VT May 13, 2024
5632023
Fix JACC call in JACC.BLAS
kmp5VT May 13, 2024
be2cec1
Small updates to threads test
kmp5VT May 13, 2024
04c9b65
Merge branch 'kmp5/refactor/internals' of https://github.com/kmp5VT/J…
kmp5VT May 13, 2024
d497ea8
Make sure eltypes are consistent in cuda tests
kmp5VT May 13, 2024
2f95204
Add back JACC BLAS, works fine on AMDGPU
kmp5VT May 13, 2024
28380a0
Merge branch 'main' into kmp5/refactor/internals
kmp5VT May 16, 2024
60e6387
Small updates to experimental `shared` function
kmp5VT May 16, 2024
09f6a33
`test_threads.jl` now can be run with any JACC backend
kmp5VT May 16, 2024
00f2ca8
Updates to tests since test_threads supports different backends
kmp5VT May 17, 2024
21 changes: 10 additions & 11 deletions ext/JACCAMDGPU/JACCAMDGPU.jl
@@ -1,24 +1,22 @@
module JACCAMDGPU

using JACC, AMDGPU

# overloaded array functions
include("array.jl")
using JACC: JACCArrayType

# overloaded experimental functions
include("JACCEXPERIMENTAL.jl")
using .experimental

function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_for(::JACCArrayType{<:ROCArray}, N::Integer, f::Function, x...)
Collaborator

We are trying to keep the API as simple as possible for domain scientists. Please see #62 for a similar discussion.

Author

I think I am confused. The outward-facing API is exactly the same as in the current implementation; it is just these two functions:

function parallel_for(N, f::Function, x...)

function parallel_reduce(N, f::Function, x...)

Internally, JACC will create a struct based on the JACC.JACCPreferences.backend variable and will only allow users to utilize a single array backend type, just like the current implementation. The only real outward-facing change for current users is JACC.Array -> JACC.array, so as not to export and overwrite Base.Array, which could cause issues in other libraries.
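
For concreteness, here is a condensed sketch of the dispatch pattern used in this PR (names follow src/JACCArrayType.jl and src/JACC.jl in the diff below; the threads method body is only illustrative, since the real backend implementations live in the package extensions):

struct JACCArrayType{T} end

arraytype(::Val{:threads}) = Base.Array
arraytype(::Val{T}) where {T} = error("The backend $(T) is either not recognized or the associated package is not loaded.")

const backend = "threads"   # stands in for JACCPreferences.backend
JACC_BACKEND_TYPE() = JACCArrayType{arraytype(Val(Symbol(backend)))}()

# Single user-facing entry point; the backend-specific method is selected by the
# JACCArrayType parameter, which is fixed by the preference at load time.
parallel_for(N, f::Function, x...) = parallel_for(JACC_BACKEND_TYPE(), N, f, x...)

# Illustrative threads method; the extensions add e.g. ::JACCArrayType{<:CuArray}.
function parallel_for(::JACCArrayType{<:Base.Array}, N::Integer, f::Function, x...)
    Threads.@threads for i in 1:N
        f(i, x...)
    end
end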

Collaborator
@williamfgc williamfgc May 9, 2024

This PR is breaking the current API. What I am trying to understand is the value added by this approach: we are adding internal boilerplate, but the end result is similar to the current implementation, unless I am missing something. As for precompilation, we intentionally want to prevent compiling on non-supported platforms, as we had a number of reported issues on our systems in the past. Multiple dispatch is a JIT "runtime concept" that we are trying to avoid altogether with weak dependencies.

Collaborator
@williamfgc williamfgc May 9, 2024

See, for example, my comment here: v0.0.3 wouldn't even compile when JACC.jl is consumed from an application and would just hang. Adding more applications helps reveal a lot of these kinds of issues.

Collaborator

so as not to export and overwrite Base.Array, which could cause issues in other libraries.

We are not overwriting Base.Array; JACC.Array is just an alias to existing Array types. Also, it's not clear what issues this could cause unless JACC.Array is misused outside its current scope. We can always implement a JACC.ToHost(JACC.Array)::Base.Array function to guarantee a host array.
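
(A minimal sketch of what such a helper could look like; ToHost is only a suggestion in this thread, not an existing JACC function, so the name and signature are hypothetical.)

# Hypothetical helper suggested above; not currently part of JACC.
# Returns a host Base.Array, copying from device memory if necessary.
ToHost(x::AbstractArray) = Base.Array(x)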

Author
@kmp5VT kmp5VT May 9, 2024

Can you please provide an example where this implementation breaks the current API outside of migrating Array to an array function?

I have found some issues with the current main branch:

julia> using JACC
julia> JACC.JACCPreferences.backend
"cuda"
julia> JACC.Array
Core.Array

If you don't have the correct backend loaded, then JACC will still run with "threads" because JACC.Array = Base.Array. This would be unexpected behavior. The system I have written here forces the component matching JACC.JACCPreferences.backend to be loaded for the user-facing API to work.

julia> using JACC
julia> begin
@show JACC.JACCPreferences.backend
    function seq_axpy(N, alpha, x, y)
        Threads.@threads for i in 1:N
            @inbounds x[i] += alpha * y[i]
        end
    end

    function axpy(i, alpha, x, y)
        if i <= length(x)
            @inbounds x[i] += alpha * y[i]
        end
    end

    N = 10
    # Generate random vectors x and y of length N for the interval [0, 100]
    x = round.(rand(Float32, N) * 100)
    y = round.(rand(Float32, N) * 100)
    alpha = 2.5

    x_host_JACC = Array(x)
    y_host_JACC = Array(y)
    JACC.parallel_for(N, axpy, alpha, x_host_JACC, y_host_JACC)
end
JACC.JACCPreferences.backend = "cuda"

ERROR: The backend cuda is either not recognized or the associated package is not loaded.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] arraytype(::Val{:cuda})
   @ JACC ~/.julia/dev/JACC.jl/src/helper.jl:15
 [3] JACC_BACKEND_TYPE()
   @ JACC ~/.julia/dev/JACC.jl/src/JACC.jl:14
 [4] parallel_for(::Int64, ::Function, ::Float64, ::Vector{Float32}, ::Vararg{Vector{Float32}})
   @ JACC ~/.julia/dev/JACC.jl/src/JACC.jl:17
 [5] top-level scope
   @ ~/.julia/dev/JACC.jl/test/tests_threads.jl:44

Currently, the Array alias creates a global variable called Array and exports it, which effectively overwrites Base.Array. This can easily cause an issue outside of JACC:

julia> using JACC
julia> begin 
       a = Array{Float32}(undef, 2,3)
       fill!(a, 10)
       a .+= 1
       end
WARNING: both JACC and Base export "Array"; uses of it in module Main must be qualified
ERROR: UndefVarError: `Array` not defined
Stacktrace:
 [1] top-level scope
   @ REPL[2]:2

The system proposed here does not need to alias Array because the backend array type is effectively stored as a compile-time parameter inside the module.

Another issue is that if a different backend is loaded after JACC, the definition of JACC.Array changes:

julia> using JACC
[ Info: Precompiling JACC [0979c8fe-16a4-4796-9b82-89a9f10403ea]
[ Info: Skipping precompilation since __precompile__(false). Importing JACC [0979c8fe-16a4-4796-9b82-89a9f10403ea].
julia> using CUDA
Precompiling JACCCUDA
  ? JACC → JACCCUDA
[ Info: Precompiling JACCCUDA [2fb45ac4-0993-536e-a71a-0b5526d52098]
┌ Warning: Module JACC with build ID ffffffff-ffff-ffff-000a-3724e6101b14 is missing from the cache.
│ This may mean JACC [0979c8fe-16a4-4796-9b82-89a9f10403ea] does not support precompilation but is imported by a module that does.
└ @ Base loading.jl:1948
[ Info: Skipping precompilation since __precompile__(false). Importing JACCCUDA [2fb45ac4-0993-536e-a71a-0b5526d52098].
julia> JACC.JACCPreferences.backend
"threads"
julia> JACC.Array
CuArray
julia> using AMDGPU
┌ Warning: Device libraries are unavailable, device intrinsics will be disabled.
└ @ AMDGPU ~/.julia/packages/AMDGPU/gtxsf/src/AMDGPU.jl:186
julia> JACC.Array
ROCArray

This is not possible with the proposed implementation because the forward-facing API effectively chooses the backend based solely on the JACCPreferences.backend variable.

I understand that you don't want to mix backends, and that is specifically why I designed the code to internally construct the JACCArrayType based on Symbol(JACCPreferences.backend). To enforce this, we should only document the two functions in JACC.jl and suggest using JACC.JACCPreferences.set_backend to change the backend array type.
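
A sketch of that workflow, assuming JACCPreferences.set_backend takes the backend name as a string and persists it via Preferences.jl (the exact signature is an assumption):

using JACC

# Persist the preference; it takes effect in a fresh Julia session so that
# precompilation picks up the selected backend.
JACC.JACCPreferences.set_backend("amdgpu")

# In the new session, load the matching backend package before calling JACC:
#   using JACC, AMDGPU
#   JACC.parallel_for(N, f, x...)   # dispatches to the ROCArray methods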

In regards to JIT and multiple dispatch, I am still learning Julia so I could be wrong, but I am not sure that is a concern with this design. The function JACC_BACKEND_TYPE() is compiler-inferable, and its result is resolved at compile time:

julia> @code_warntype JACC.JACC_BACKEND_TYPE()
MethodInstance for JACC.JACC_BACKEND_TYPE()
  from JACC_BACKEND_TYPE() @ JACC ~/.julia/dev/JACC.jl/src/JACC.jl:13
Arguments
  #self#::Core.Const(JACC.JACC_BACKEND_TYPE)
Body::JACC.JACCArrayType{Array}
1 ─ %1 = JACC.JACCArrayType::Core.Const(JACC.JACCArrayType)
│   %2 = JACC.JACCPreferences.backend::Core.Const("threads")
│   %3 = JACC.Symbol(%2)::Core.Const(:threads)
│   %4 = JACC.Val(%3)::Core.Const(Val{:threads}())
│   %5 = JACC.arraytype(%4)::Core.Const(Array)
│   %6 = Core.apply_type(%1, %5)::Core.Const(JACC.JACCArrayType{Array})
│   %7 = (%6)()::Core.Const(JACC.JACCArrayType{Array}())
└──      return %7
julia> @code_typed JACC.JACC_BACKEND_TYPE()
CodeInfo(
1 ─     return $(QuoteNode(JACC.JACCArrayType{Array}()))
) => JACC.JACCArrayType{Array}

Collaborator

Can you please provide an example where this implementation breaks the current API outside of migrating Array to an array function?

I think this is self-explanatory.

Currently, the Array alias creates a global variable called Array and exports it, which effectively overwrites Base.Array. This can easily cause an issue outside of JACC.

using X is certainly not recommended, but import X is, due to name clashing; I think this is different from overriding Base.Array.
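
A small illustration of the point, assuming the current main branch where JACC exports Array:

import JACC                              # no exported names are brought into Main

a = Base.Array{Float32}(undef, 2, 3)     # unambiguously Base's Array
b = JACC.Array                           # the backend alias, fully qualified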

This is not possible with the proposed implementation because the forward-facing API effectively chooses the backend based solely on the JACCPreferences.backend variable.

In regards to JIT and multiple dispatch, I am still learning Julia so I could be wrong, but I am not sure that is a concern with this design. The function JACC_BACKEND_TYPE() is compiler-inferable, and its result is resolved at compile time.

Packages still need to be downloaded and precompiled. That's where we saw the reported issues on our systems, and that is why we are being very conservative, as explained in our last call. Even more so with relatively new Julia versions and features, and a rapidly evolving stack for AMD and Intel GPUs.

I just think the motivation and problems tackled for our systems by the current API and implementation are very different from those in this PR.

Author

@williamfgc I do not understand your systems, as the code patterns here work on my clusters without issue and all of the unit tests in the module pass. I would love to understand how this PR breaks on your systems; if you could run the code in this PR and give me examples, that would be appreciated. I specifically tailored the design of this PR based on our conversations on Zoom. Further, I connected with @pedrovalerolara to collaborate with you on code design, with the potential of integrating JACC into ITensors. I do not have specific use cases for your code and have not used any of your software previously.

Collaborator
@williamfgc williamfgc May 9, 2024

That would be a good starting point for collaboration. Feel free to integrate JACC with ITensors and understand the trade-offs and the motivation of the current design before changing all internals. As discussed in the call, we target DOE HPC systems, which are not simple to deploy and use, especially with new stuff like the Julia ecosystem.


Alternative idea: can we do struct Array{T,N} end and then have Array{T,N}(...) = JACC.arraytype(){T,N}(...)? This would preserve the idea that JACC.Array can become another array type, without modifying module globals at runtime.
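
A quick sketch of what that could look like (the wrapper module name and the threads-only fallback are placeholders for illustration):

module JACCArraySketch

arraytype() = Base.Array                 # stand-in for JACC.arraytype()

struct Array{T, N} end                   # never instantiated; just a dispatch shim

Array{T, N}(args...) where {T, N} = arraytype(){T, N}(args...)
Array{T}(args...) where {T} = arraytype(){T}(args...)

end # module

a = JACCArraySketch.Array{Float32, 2}(undef, 2, 3)   # yields a Base.Matrix{Float32}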

numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
@roc groupsize=threads gridsize=blocks _parallel_for_amdgpu(f, x...)
AMDGPU.synchronize()
end

function JACC.parallel_for(
(M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_for(::JACCArrayType{<:ROCArray},
(M, N)::Tuple{Integer, Integer}, f::Function, x...)
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
@@ -29,8 +27,8 @@ function JACC.parallel_for(
AMDGPU.synchronize()
end

function JACC.parallel_reduce(
N::I, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(::JACCArrayType{<:ROCArray},
N::Integer, f::Function, x...)
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
@@ -45,8 +43,8 @@ function JACC.parallel_reduce(
return rret
end

function JACC.parallel_reduce(
(M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(::JACCArrayType{<:ROCArray},
(M, N)::Tuple{Integer, Integer}, f::Function, x...)
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
@@ -313,8 +311,9 @@ function reduce_kernel_amdgpu_MN((M, N), red, ret)
return nothing
end

JACC.arraytype(::Val{:amdgpu}) = ROCArray

function __init__()
const JACC.Array = AMDGPU.ROCArray{T, N} where {T, N}
end

end # module JACCAMDGPU
3 changes: 2 additions & 1 deletion ext/JACCAMDGPU/JACCEXPERIMENTAL.jl
@@ -2,7 +2,8 @@ module experimental

using JACC, AMDGPU

function JACC.experimental.shared(x::ROCDeviceArray{T,N}) where {T,N}
function JACC.experimental.shared(x::ROCDeviceArray)
T = eltype(x)
size = length(x)
shmem = @ROCDynamicLocalArray(T, size)
num_threads = workgroupDim().x * workgroupDim().y
8 changes: 0 additions & 8 deletions ext/JACCAMDGPU/array.jl

This file was deleted.

22 changes: 10 additions & 12 deletions ext/JACCCUDA/JACCCUDA.jl
@@ -1,15 +1,12 @@
module JACCCUDA

using JACC, CUDA
using JACC: JACCArrayType

# overloaded array functions
include("array.jl")

# overloaded experimental functions
include("JACCEXPERIMENTAL.jl")
using .experimental

function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_for(::JACCArrayType{<:CuArray}, N::Integer, f::Function, x...)
parallel_args = (N, f, x...)
parallel_kargs = cudaconvert.(parallel_args)
parallel_tt = Tuple{Core.Typeof.(parallel_kargs)...}
@@ -20,8 +17,8 @@ function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
parallel_kernel(parallel_kargs...; threads = threads, blocks = blocks)
end

function JACC.parallel_for(
(M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_for(::JACCArrayType{<:CuArray},
(M, N)::Tuple{Integer, Integer}, f::Function, x...)
#To use JACC.shared, it is recommended to use a high number of threads per block to maximize the
# potential benefit from using shared memory.
#numThreads = 32
@@ -37,8 +34,8 @@ function JACC.parallel_for(
# f, x...)
end

function JACC.parallel_reduce(
N::I, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(::JACCArrayType{<:CuArray},
N::Integer, f::Function, x...)
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
@@ -51,8 +48,8 @@ function JACC.parallel_reduce(
return rret
end

function JACC.parallel_reduce(
(M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(::JACCArrayType{<:CuArray},
(M, N)::Tuple{Integer, Integer}, f::Function, x...)
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
@@ -322,8 +319,9 @@ function reduce_kernel_cuda_MN((M, N), red, ret)
return nothing
end

JACC.arraytype(::Val{:cuda}) = CuArray

function __init__()
const JACC.Array = CUDA.CuArray{T, N} where {T, N}
end

end # module JACCCUDA
3 changes: 2 additions & 1 deletion ext/JACCCUDA/JACCEXPERIMENTAL.jl
@@ -2,7 +2,8 @@ module experimental

using JACC, CUDA

function JACC.experimental.shared(x::CuDeviceArray{T,N}) where {T,N}
function JACC.experimental.shared(x::CuDeviceArray)
T = eltype(x)
size = length(x)
shmem = @cuDynamicSharedMem(T, size)
num_threads = blockDim().x * blockDim().y
8 changes: 0 additions & 8 deletions ext/JACCCUDA/array.jl

This file was deleted.

3 changes: 2 additions & 1 deletion ext/JACCONEAPI/JACCEXPERIMENTAL.jl
@@ -2,7 +2,8 @@ module experimental

using JACC, oneAPI

function JACC.experimental.shared(x::oneDeviceArray{T,N}) where {T,N}
function JACC.experimental.shared(x::oneDeviceArray)
T = eltype(x)
size = length(x)
shmem = oneLocalArray(T, size)
num_threads = get_local_size(0) * get_local_size(1)
21 changes: 10 additions & 11 deletions ext/JACCONEAPI/JACCONEAPI.jl
@@ -1,24 +1,22 @@
module JACCONEAPI

using JACC, oneAPI

# overloaded array functions
include("array.jl")
using JACC: JACCArrayType

# overloaded experimental functions
include("JACCEXPERIMENTAL.jl")
using .experimental

function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_for(::JACCArrayType{<:oneArray}, N::Integer, f::Function, x...)
#maxPossibleItems = oneAPI.oneL0.compute_properties(device().maxTotalGroupSize)
maxPossibleItems = 256
items = min(N, maxPossibleItems)
groups = ceil(Int, N / items)
oneAPI.@sync @oneapi items=items groups=groups _parallel_for_oneapi(f, x...)
end

function JACC.parallel_for(
(M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_for(::JACCArrayType{<:oneArray},
(M, N)::Tuple{Integer, Integer}, f::Function, x...)
maxPossibleItems = 16
Mitems = min(M, maxPossibleItems)
Nitems = min(N, maxPossibleItems)
@@ -28,8 +26,8 @@ function JACC.parallel_for(
f, x...)
end

function JACC.parallel_reduce(
N::I, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(::JACCArrayType{<:oneArray},
N::Integer, f::Function, x...)
numItems = 256
items = min(N, numItems)
groups = ceil(Int, N / items)
@@ -41,8 +39,8 @@ function JACC.parallel_reduce(
return rret
end

function JACC.parallel_reduce(
(M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(::JACCArrayType{<:oneArray},
(M, N)::Tuple{Integer, Integer}, f::Function, x...)
numItems = 16
Mitems = min(M, numItems)
Nitems = min(N, numItems)
@@ -306,8 +304,9 @@ function reduce_kernel_oneapi_MN((M, N), red, ret)
return nothing
end

JACC.arraytype(::Val{:oneapi}) = oneArray

function __init__()
const JACC.Array = oneAPI.oneArray{T, N} where {T, N}
end

end # module JACCONEAPI
50 changes: 7 additions & 43 deletions src/JACC.jl
@@ -1,68 +1,32 @@
__precompile__(false)
module JACC

import Atomix: @atomic
using Atomix: @atomic
# module to set back end preferences
include("JACCArrayType.jl")
include("JACCPreferences.jl")
include("helper.jl")
# overloaded array functions
include("array.jl")


include("JACCBLAS.jl")
using .BLAS

include("JACCEXPERIMENTAL.jl")
using .experimental

export Array, @atomic
export parallel_for

global Array

function parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
@maybe_threaded for i in 1:N
f(i, x...)
end
end

function parallel_for(
(M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
@maybe_threaded for j in 1:N
for i in 1:M
f(i, j, x...)
end
end
end

function parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
tmp = zeros(Threads.nthreads())
ret = zeros(1)
@maybe_threaded for i in 1:N
tmp[Threads.threadid()] = tmp[Threads.threadid()] .+ f(i, x...)
end
for i in 1:Threads.nthreads()
ret = ret .+ tmp[i]
end
return ret
function parallel_for(N, f::Function, x...)
return parallel_for(JACCPreferences.JACC_BACKEND_TYPE(), N, f, x...)
end

function parallel_reduce(
(M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
tmp = zeros(Threads.nthreads())
ret = zeros(1)
@maybe_threaded for j in 1:N
for i in 1:M
tmp[Threads.threadid()] = tmp[Threads.threadid()] .+ f(i, j, x...)
end
end
for i in 1:Threads.nthreads()
ret = ret .+ tmp[i]
end
return ret
function parallel_reduce(N, f::Function, x...)
return parallel_reduce(JACCPreferences.JACC_BACKEND_TYPE(), N, f, x...)
end

function __init__()
const JACC.Array = Base.Array{T, N} where {T, N}
end

end # module JACC
8 changes: 8 additions & 0 deletions src/JACCArrayType.jl
@@ -0,0 +1,8 @@
struct JACCArrayType{T}
end

arraytype() = arraytype(Val(Symbol(JACCPreferences.backend)))
arraytype(::Val{:threads}) = Array
arraytype(::Val{T}) where T = error("The backend $(T) is either not recognized or the associated package is not loaded.")
arraytype(J::JACCArrayType) = arraytype(typeof(J))
arraytype(::Type{<:JACCArrayType{T}}) where {T} = T
4 changes: 2 additions & 2 deletions src/JACCBLAS.jl
@@ -10,11 +10,11 @@ function _dot(i, x, y)
return @inbounds x[i] * y[i]
end

function axpy(n::I, alpha, x, y) where {I<:Integer}
function axpy(n::Integer, alpha, x, y)
JACC.parallel_for(n, _axpy, alpha, x, y)
end

function dot(n::I, x, y) where {I<:Integer}
function dot(n::Integer, x, y)
JACC.parallel_reduce(n, _dot, x, y)
end

2 changes: 1 addition & 1 deletion src/JACCEXPERIMENTAL.jl
@@ -2,7 +2,7 @@ module experimental

using JACC

function shared(x::Base.Array{T,N}) where {T,N}
function shared(x)
return x
end

6 changes: 6 additions & 0 deletions src/JACCPreferences.jl
@@ -17,4 +17,10 @@

const backend = @load_preference("backend", "threads")

using JACC: JACCArrayType, arraytype

function JACC_BACKEND_TYPE()
return JACCArrayType{arraytype(Val(Symbol(JACCPreferences.backend)))}()
end

end # module JACCPreferences