Skip to content

Commit

Permalink
CuArrayBackend -> CUDABackend
Browse files Browse the repository at this point in the history
  • Loading branch information
leios committed Jul 22, 2024
1 parent 9590be3 commit 54599b8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
13 changes: 8 additions & 5 deletions src/CUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ include("compiler/execution.jl")
include("compiler/exceptions.jl")
include("compiler/reflection.jl")

# KernelAbstractions
include("CUDAKernels.jl")
import .CUDAKernels: CUDABackend, KA
export CUDABackend

# array implementation
include("gpuarrays.jl")
include("utilities.jl")
Expand Down Expand Up @@ -111,6 +116,9 @@ export CUBLAS, CUSPARSE, CUSOLVER, CUFFT, CURAND
const has_cusolvermg = CUSOLVER.has_cusolvermg
export has_cusolvermg

# KA Backend Definition
KA.get_backend(::CUSPARSE.AbstractCuSparseArray) = CUDABackend()

# random depends on CURAND
include("random.jl")

Expand All @@ -119,11 +127,6 @@ include("../lib/nvml/NVML.jl")
const has_nvml = NVML.has_nvml
export NVML, has_nvml

# KernelAbstractions
include("CUDAKernels.jl")
import .CUDAKernels: CUDABackend
export CUDABackend

# StaticArrays is still a direct dependency, so directly include the extension
include("../ext/StaticArraysExt.jl")

Expand Down
1 change: 0 additions & 1 deletion src/CUDAKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ KA.zeros(::CUDABackend, ::Type{T}, dims::Tuple) where T = CUDA.zeros(T, dims)
KA.ones(::CUDABackend, ::Type{T}, dims::Tuple) where T = CUDA.ones(T, dims)

KA.get_backend(::CuArray) = CUDABackend()
KA.get_backend(::CUSPARSE.AbstractCuSparseArray) = CUDABackend()
KA.synchronize(::CUDABackend) = synchronize()

Adapt.adapt_storage(::CUDABackend, a::Array) = Adapt.adapt(CuArray, a)
Expand Down
6 changes: 1 addition & 5 deletions src/gpuarrays.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
# GPUArrays.jl interface

import KernelAbstractions
import KernelAbstractions: Backend

#
# Device functionality
#


## execution

struct CuArrayBackend <: Backend end

@inline function GPUArrays.launch_heuristic(::CuArrayBackend, f::F, args::Vararg{Any,N};
@inline function GPUArrays.launch_heuristic(::CUDABackend, f::F, args::Vararg{Any,N};
elements::Int, elements_per_thread::Int) where {F,N}
kernel = @cuda launch=false f(CuKernelContext(), args...)

Expand Down

0 comments on commit 54599b8

Please sign in to comment.