Skip to content

Constant memory support #552

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions lib/cudadrv/module/global.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
# - should be more dict-like: get and setindex(::name), haskey(::name)
# - globals(::Type)?

export
CuGlobal, get, set
export CuGlobal, get, set, CuGlobalArray


"""
Expand Down Expand Up @@ -62,3 +61,44 @@ function Base.setindex!(var::CuGlobal{T}, val::T) where T
val_ref = Ref{T}(val)
cuMemcpyHtoD_v2(var, val_ref, var.buf.bytesize)
end

"""
CuGlobalArray{T}(mod::CuModule, name::String, len::Integer)

Acquires a global array variable handle from a named global in a module.
"""
struct CuGlobalArray{T} # TODO: the functionality provided by this struct can most likely be merged into CuGlobal{T}
buf::Mem.DeviceBuffer
len::Integer

function CuGlobalArray{T}(mod::CuModule, name::String, len::Integer) where T
ptr_ref = Ref{CuPtr{Cvoid}}()
nbytes_ref = Ref{Csize_t}()
cuModuleGetGlobal_v2(ptr_ref, nbytes_ref, mod, name)
if nbytes_ref[] != (sizeof(T) * len)
throw(ArgumentError("size of global array '$name' ($(nbytes_ref[])) does not match given size (sizeof($T) * $length)"))
end
buf = Mem.DeviceBuffer(ptr_ref[], nbytes_ref[])

return new{T}(buf, len)
end
end

# Element type of the array, recovered from the type parameter.
Base.eltype(::Type{CuGlobalArray{T}}) where {T} = T

# Number of elements, as recorded at construction time.
Base.length(ga::CuGlobalArray) = ga.len

# Size in bytes, delegated to the underlying device buffer.
Base.sizeof(ga::CuGlobalArray) = sizeof(ga.buf)

"""
    copyto!(global_array::CuGlobalArray{T}, src::Array{T})

Upload the contents of `src` to the device memory backing `global_array`.
The byte sizes must match exactly, otherwise a `DimensionMismatch` is thrown.
Returns `global_array`, following the `Base.copyto!` convention of returning
the destination (the original returned the raw driver-call result).
"""
function Base.copyto!(global_array::CuGlobalArray{T}, src::Array{T}) where {T}
    if sizeof(src) != sizeof(global_array)
        throw(DimensionMismatch("size of `src` ($(sizeof(src))) does not match global array ($(sizeof(global_array)))"))
    end
    cuMemcpyHtoD_v2(global_array.buf, src, sizeof(src))
    return global_array
end

# Download the device-side contents of the global array into a fresh host Vector.
function Base.collect(global_array::CuGlobalArray{T}) where {T}
    result = Vector{T}(undef, length(global_array))
    cuMemcpyDtoH_v2(result, global_array.buf, sizeof(global_array))
    return result
end
2 changes: 2 additions & 0 deletions src/CUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ include("array.jl")
include("gpuarrays.jl")
include("utilities.jl")
include("texture.jl")
include("memory_constant.jl")
include("memory_global.jl")

# array libraries
include("../lib/complex.jl")
Expand Down
50 changes: 45 additions & 5 deletions src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ macro cuda(ex...)

# FIXME: macro hygiene wrt. escaping kwarg values (this broke with 1.5)
# we esc() the whole thing now, necessitating gensyms...
@gensym f_var kernel_f kernel_args kernel_tt kernel
@gensym f_var kernel_f kernel_args kernel_tt kernel memory_to_init
if dynamic
# FIXME: we could probably somehow support kwargs with constant values by either
# saving them in a global Dict here, or trying to pick them up from the Julia
Expand Down Expand Up @@ -98,8 +98,9 @@ macro cuda(ex...)
GC.@preserve $(vars...) $f_var begin
local $kernel_f = $cudaconvert($f_var)
local $kernel_args = map($cudaconvert, ($(var_exprs...),))
local $memory_to_init = $find_memory_to_init($f, ($(var_exprs...), ))
local $kernel_tt = Tuple{map(Core.Typeof, $kernel_args)...}
local $kernel = $cufunction($kernel_f, $kernel_tt; $(compiler_kwargs...))
local $kernel = $cufunction($kernel_f, $kernel_tt; memory_to_init=$memory_to_init, $(compiler_kwargs...))
if $launch
$kernel($(var_exprs...); $(call_kwargs...))
end
Expand All @@ -110,6 +111,20 @@ macro cuda(ex...)
return esc(code)
end

# Scan the kernel object's fields (closure captures) and the kernel arguments
# for constant/global memory objects that need device-side initialization.
# Returns a Vector{Any}, matching the `memory_to_init` kwarg of `cufunction`.
function find_memory_to_init(kernel, kernel_args)
    candidates = Any[getfield(kernel, i) for i in 1:nfields(kernel)]
    append!(candidates, kernel_args)

    return Any[c for c in candidates if c isa CuConstantMemory || c isa CuGlobalMemory]
end

## host to device value conversion

Expand Down Expand Up @@ -201,17 +216,32 @@ struct HostKernel{F,TT} <: AbstractKernel{F,TT}
ctx::CuContext
mod::CuModule
fun::CuFunction
tracked_memory::Vector{Any}
end

@doc (@doc AbstractKernel) HostKernel

# Launch wrapper that keeps tracked constant/global memory in sync with its
# host-side mirror: upload current values before the launch, and read back
# device-side mutations afterwards.
@inline function cudacall(kernel::HostKernel, tt, args...; config=nothing, kwargs...)
    # upload the host-side value of every tracked memory object so the kernel
    # observes up-to-date data
    for memory in kernel.tracked_memory
        global_array = CuGlobalArray{eltype(memory)}(kernel.mod, memory.name, length(memory))
        copyto!(global_array, memory.value)
    end

    if config !== nothing
        # fixed typo in the user-facing message: "instrospect" -> "introspect"
        Base.depwarn("cudacall with config argument is deprecated, use `@cuda launch=false` and introspect the returned kernel instead", :cudacall)
        cudacall(kernel.fun, tt, args...; kwargs..., config(kernel)...)
    else
        cudacall(kernel.fun, tt, args...; kwargs...)
    end

    # read back device-side mutations into the host mirror
    for memory in kernel.tracked_memory
        if isa(memory, CuConstantMemory)
            continue # constant memory is read only, skip it
        end
        global_array = CuGlobalArray{eltype(memory)}(kernel.mod, memory.name, length(memory))
        # collect() allocates a fresh host buffer and reshape() shares it, so no
        # defensive deepcopy is needed before storing it in `memory.value`
        memory.value = reshape(collect(global_array), size(memory))
    end
end

"""
Expand Down Expand Up @@ -283,13 +313,13 @@ The output of this function is automatically cached, i.e. you can simply call `c
in a hot path without degrading performance. New code will be generated automatically
when the function changes, or when different types or keyword arguments are provided.
"""
function cufunction(f::F, tt::TT=Tuple{}; name=nothing, kwargs...) where
function cufunction(f::F, tt::TT=Tuple{}; name=nothing, memory_to_init::Vector{Any}=[], kwargs...) where
{F<:Core.Function, TT<:Type}
dev = device()
cache = cufunction_cache[dev]
source = FunctionSpec(f, tt, true, name)
target = CUDACompilerTarget(dev; kwargs...)
params = CUDACompilerParams()
params = CUDACompilerParams(memory_to_init)
job = CompilerJob(target, source, params)
return GPUCompiler.cached_compilation(cache, job,
cufunction_compile,
Expand Down Expand Up @@ -358,7 +388,17 @@ function cufunction_link(@nospecialize(job::CompilerJob), compiled)
filter!(isequal("exception_flag"), compiled.external_gvars)
end

return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun)
tracked_memory = []

for memory in job.params.memory_to_init
global_array = CuGlobalArray{eltype(memory)}(mod, memory.name, length(memory))
copyto!(global_array, memory.value)
if memory.track_value_between_kernels
push!(tracked_memory, memory)
end
end

return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun, tracked_memory)
end

# https://github.com/JuliaLang/julia/issues/14919
Expand Down
7 changes: 6 additions & 1 deletion src/compiler/gpucompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ function CUDACompilerTarget(dev::CuDevice; kwargs...)
PTXCompilerTarget(; cap, exitable, debuginfo, kwargs...)
end

struct CUDACompilerParams <: AbstractCompilerParams end
# Compiler parameters carrying the constant/global memory objects that must be
# initialized in the compiled module before launch.
struct CUDACompilerParams <: AbstractCompilerParams
    memory_to_init::Vector{Any}

    # short-form inner constructor; defaults to no memory to initialize
    CUDACompilerParams(memory_to_init::Vector{Any}=[]) = new(memory_to_init)
end

CUDACompilerJob = CompilerJob{PTXCompilerTarget,CUDACompilerParams}

Expand Down
2 changes: 2 additions & 0 deletions src/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ include("intrinsics/memory_dynamic.jl")
include("intrinsics/atomics.jl")
include("intrinsics/misc.jl")
include("intrinsics/wmma.jl")
include("intrinsics/memory_constant.jl")
include("intrinsics/memory_global.jl")

# functionality from libdevice
#
Expand Down
75 changes: 75 additions & 0 deletions src/device/intrinsics/memory_constant.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Constant Memory

export CuDeviceConstantMemory

"""
CuDeviceConstantMemory{T,N,Name,Shape}

The device-side counterpart of [`CuConstantMemory{T,N}`](@ref). This type should not be used
directly except in the case of `CuConstantMemory` global variables, where it denotes the
type of the relevant kernel argument.

Note that the `Name` and `Shape` type variables are implementation details and it
discouraged to use them directly. Instead use [name(::CuConstantMemory)](@ref) and
[Base.size(::CuConstantMemory)](@ref) respectively.
"""
struct CuDeviceConstantMemory{T,N,Name,Shape} <: AbstractArray{T,N} end

"""
Get the name of underlying global variable of this `CuDeviceConstantMemory`.
"""
name(::CuDeviceConstantMemory{T,N,Name}) where {T,N,Name} = Name

Base.:(==)(A::CuDeviceConstantMemory, B::CuDeviceConstantMemory) = name(A) == name(B)
Base.hash(A::CuDeviceConstantMemory, h::UInt) = hash(name(A), h)

Base.size(::CuDeviceConstantMemory{T,N,Name,Shape}) where {T,N,Name,Shape} = Shape

Base.@propagate_inbounds Base.getindex(A::CuDeviceConstantMemory, i::Integer) = constmemref(A, i)

Base.IndexStyle(::Type{<:CuDeviceConstantMemory}) = Base.IndexLinear()

# Bounds-checked scalar read from the constant-memory global backing `A`.
# The name and length are lifted to the type domain via `Val` so the
# generated reader can bake them into the emitted IR.
@inline function constmemref(A::CuDeviceConstantMemory{T,N,Name,Shape}, index::Integer) where {T,N,Name,Shape}
    @boundscheck checkbounds(A, index)
    return read_constant_mem(Val(Name), index, T, Val(length(A)))
end

# Generated reader for constant memory: emits an LLVM function that indexes a
# weak, externally-initialized global array in the constant address space and
# loads one element. `global_name` and `len` are type-domain parameters so each
# (name, eltype, length) combination gets its own specialized IR.
@generated function read_constant_mem(::Val{global_name}, index::Integer, ::Type{T}, ::Val{len}) where {global_name,T,len}
    JuliaContext() do ctx
        # define LLVM types
        T_int = convert(LLVMType, Int, ctx)
        T_result = convert(LLVMType, T, ctx)

        # define function and get LLVM module
        param_types = [T_int]
        llvm_f, _ = create_function(T_result, param_types)
        mod = LLVM.parent(llvm_f)

        # create a constant memory global variable
        # TODO: global_var alignment?
        T_global = LLVM.ArrayType(T_result, len)
        global_var = GlobalVariable(mod, T_global, string(global_name), AS.Constant)
        linkage!(global_var, LLVM.API.LLVMWeakAnyLinkage) # merge, but make sure symbols aren't discarded
        initializer!(global_var, null(T_global))
        extinit!(global_var, true)
        # XXX: if we don't extinit, LLVM can inline the constant memory if it's predefined.
        #      that means we wouldn't be able to re-set it afterwards. do we want that?

        # generate IR
        Builder(ctx) do builder
            entry = BasicBlock(llvm_f, "entry", ctx)
            position!(builder, entry)

            # GEP into the array: first index dereferences the global pointer,
            # second selects the element (already zero-based at the call site)
            typed_ptr = inbounds_gep!(builder, global_var, [ConstantInt(0, ctx), parameters(llvm_f)[1]])
            ld = load!(builder, typed_ptr)

            # tag the load with TBAA metadata for the constant address space
            metadata(ld)[LLVM.MD_tbaa] = tbaa_addrspace(AS.Constant, ctx)

            ret!(builder, ld)
        end

        # call the function; convert the 1-based Julia index to 0-based
        call_function(llvm_f, T, Tuple{Int}, :((Int(index - one(index))),))
    end
end
109 changes: 109 additions & 0 deletions src/device/intrinsics/memory_global.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Statically Allocated Global Memory

export CuDeviceGlobalMemory

"""
    CuDeviceGlobalMemory{T,N,Name,Shape}

Device-side array backed by a statically allocated global-memory variable.
Unlike constant memory it is writable from the device. `Name` and `Shape` are
implementation details carried in the type domain.
"""
struct CuDeviceGlobalMemory{T,N,Name,Shape} <: AbstractArray{T,N} end

# Name of the underlying global variable (type-domain parameter).
name(::CuDeviceGlobalMemory{T,N,Name,Shape}) where {T,N,Name,Shape} = Name

# Equality and hashing are keyed on the global-variable name only; arrays with
# the same name are treated as referring to the same device memory.
Base.:(==)(A::CuDeviceGlobalMemory, B::CuDeviceGlobalMemory) = name(A) == name(B)
Base.hash(A::CuDeviceGlobalMemory, h::UInt) = hash(name(A), h)

# The shape lives in the type domain (`Shape` parameter).
Base.size(::CuDeviceGlobalMemory{T,N,Name,Shape}) where {T,N,Name,Shape} = Shape

# Linear indexing: reads go through `globalmemref`, writes through
# `globalmemset` (converting the value to the element type first).
Base.@propagate_inbounds Base.getindex(A::CuDeviceGlobalMemory, i::Integer) =
    globalmemref(A, i)
Base.@propagate_inbounds Base.setindex!(A::CuDeviceGlobalMemory{T}, x, i::Integer) where {T} =
    globalmemset(A, convert(T, x), i)

Base.IndexStyle(::Type{<:CuDeviceGlobalMemory}) = Base.IndexLinear()

# Bounds-checked scalar read from the global-memory variable backing `A`.
@inline function globalmemref(A::CuDeviceGlobalMemory{T,N,Name,Shape}, index::Integer) where {T,N,Name,Shape}
    @boundscheck checkbounds(A, index)
    return read_global_mem(Val(Name), index, T, Val(length(A)))
end

# Bounds-checked scalar write into the global-memory variable backing `A`.
# Returns `A`, as `setindex!` implementations conventionally do.
@inline function globalmemset(A::CuDeviceGlobalMemory{T,N,Name,Shape}, x::T, index::Integer) where {T,N,Name,Shape}
    @boundscheck checkbounds(A, index)
    write_global_mem(Val(Name), index, x, Val(length(A)))
    return A
end

# Generated reader for statically allocated global memory: emits an LLVM
# function that indexes a weak, externally-initialized global array in the
# global address space and loads one element. Mirrors `read_constant_mem`
# except for the address space.
@generated function read_global_mem(::Val{global_name}, index::Integer, ::Type{T}, ::Val{len}) where {global_name,T,len}
    JuliaContext() do ctx
        # define LLVM types
        T_int = convert(LLVMType, Int, ctx)
        T_result = convert(LLVMType, T, ctx)

        # define function and get LLVM module
        param_types = [T_int]
        llvm_f, _ = create_function(T_result, param_types)
        mod = LLVM.parent(llvm_f)

        # create a global memory global variable
        # TODO: global_var alignment?
        T_global = LLVM.ArrayType(T_result, len)
        global_var = GlobalVariable(mod, T_global, string(global_name), AS.Global)
        linkage!(global_var, LLVM.API.LLVMWeakAnyLinkage) # merge, but make sure symbols aren't discarded
        initializer!(global_var, null(T_global))
        extinit!(global_var, true)
        # XXX: if we don't extinit, LLVM can inline the global memory if it's predefined.
        #      that means we wouldn't be able to re-set it afterwards. do we want that?

        # generate IR
        Builder(ctx) do builder
            entry = BasicBlock(llvm_f, "entry", ctx)
            position!(builder, entry)

            # GEP into the array: first index dereferences the global pointer,
            # second selects the element (already zero-based at the call site)
            typed_ptr = inbounds_gep!(builder, global_var, [ConstantInt(0, ctx), parameters(llvm_f)[1]])
            ld = load!(builder, typed_ptr)

            # tag the load with TBAA metadata for the global address space
            metadata(ld)[LLVM.MD_tbaa] = tbaa_addrspace(AS.Global, ctx)

            ret!(builder, ld)
        end

        # call the function; convert the 1-based Julia index to 0-based
        call_function(llvm_f, T, Tuple{Int}, :((Int(index - one(index))),))
    end
end

# Generated writer for statically allocated global memory: emits an LLVM
# function that indexes the same weak global array as `read_global_mem` and
# stores one element of type `T` into it.
@generated function write_global_mem(::Val{global_name}, index::Integer, x::T, ::Val{len}) where {global_name,T,len}
    JuliaContext() do ctx
        # define LLVM types
        T_int = convert(LLVMType, Int, ctx)
        eltyp = convert(LLVMType, T, ctx)

        # define function and get LLVM module; returns void (store only)
        param_types = [eltyp, T_int]
        llvm_f, _ = create_function(LLVM.VoidType(ctx), param_types)
        mod = LLVM.parent(llvm_f)

        # create a global memory global variable
        # TODO: global_var alignment?
        T_global = LLVM.ArrayType(eltyp, len)
        global_var = GlobalVariable(mod, T_global, string(global_name), AS.Global)
        linkage!(global_var, LLVM.API.LLVMWeakAnyLinkage) # merge, but make sure symbols aren't discarded
        initializer!(global_var, null(T_global))
        extinit!(global_var, true)

        # generate IR
        Builder(ctx) do builder
            entry = BasicBlock(llvm_f, "entry", ctx)
            position!(builder, entry)

            # parameters: [1] is the value to store, [2] the zero-based index
            typed_ptr = inbounds_gep!(builder, global_var, [ConstantInt(0, ctx), parameters(llvm_f)[2]])
            val = parameters(llvm_f)[1]
            st = store!(builder, val, typed_ptr)

            # tag the store with TBAA metadata for the global address space
            metadata(st)[LLVM.MD_tbaa] = tbaa_addrspace(AS.Global, ctx)

            ret!(builder)
        end

        # call the function; convert the 1-based Julia index to 0-based
        call_function(llvm_f, Cvoid, Tuple{T, Int}, :((x, Int(index - one(index)))))
    end
end
Loading