Skip to content

Commit 02ffcaa

Browse files
committed
Perform eager initialization, retain undefined-ness, pass initializer as Val.
1 parent 0dbaf72 commit 02ffcaa

File tree

3 files changed

+90
-90
lines changed

3 files changed

+90
-90
lines changed

src/compiler/gpucompiler.jl

-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ function GPUCompiler.finish_module!(job::CUDACompilerJob, mod::LLVM.Module)
3535
Tuple{CompilerJob{PTXCompilerTarget}, typeof(mod)},
3636
job, mod)
3737
emit_exception_flag!(mod)
38-
emit_constant_memory_initializer!(mod)
3938
end
4039

4140
function GPUCompiler.link_libraries!(job::CUDACompilerJob, mod::LLVM.Module,

src/device/intrinsics/memory_constant.jl

+63-7
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ Note that the `Name` and `Shape` type variables are implementation details and i
1313
discouraged to use them directly. Instead use [name(::CuConstantMemory)](@ref) and
1414
[Base.size(::CuConstantMemory)](@ref) respectively.
1515
"""
16-
struct CuDeviceConstantMemory{T,N,Name,Shape} <: AbstractArray{T,N} end
16+
struct CuDeviceConstantMemory{T,N,Name,Shape,Hash} <: AbstractArray{T,N} end
1717

1818
"""
1919
Get the name of underlying global variable of this `CuDeviceConstantMemory`.
2020
"""
21-
name(::CuDeviceConstantMemory{T,N,Name,Shape}) where {T,N,Name,Shape} = Name
21+
name(::CuDeviceConstantMemory{T,N,Name}) where {T,N,Name} = Name
2222

2323
Base.:(==)(A::CuDeviceConstantMemory, B::CuDeviceConstantMemory) = name(A) == name(B)
2424
Base.hash(A::CuDeviceConstantMemory, h::UInt) = hash(name(A), h)
@@ -29,13 +29,13 @@ Base.@propagate_inbounds Base.getindex(A::CuDeviceConstantMemory, i::Integer) =
2929

3030
Base.IndexStyle(::Type{<:CuDeviceConstantMemory}) = Base.IndexLinear()
3131

32-
@inline function constmemref(A::CuDeviceConstantMemory{T,N,Name,Shape}, index::Integer) where {T,N,Name,Shape}
32+
@inline function constmemref(A::CuDeviceConstantMemory{T,N,Name,Shape,Init}, index::Integer) where {T,N,Name,Shape,Init}
3333
@boundscheck checkbounds(A, index)
3434
len = length(A)
35-
return read_constant_mem(Val(Name), index, T, Val(len))
35+
return read_constant_mem(Val(Name), index, T, Val(Shape), Val(Init))
3636
end
3737

38-
@generated function read_constant_mem(::Val{global_name}, index::Integer, ::Type{T}, ::Val{len}) where {global_name,T,len}
38+
@generated function read_constant_mem(::Val{global_name}, index::Integer, ::Type{T}, ::Val{shape}, ::Val{init}) where {global_name,T,shape,init}
3939
JuliaContext() do ctx
4040
# define LLVM types
4141
T_int = convert(LLVMType, Int, ctx)
@@ -47,11 +47,67 @@ end
4747
mod = LLVM.parent(llvm_f)
4848

4949
# create a constant memory global variable
50+
# TODO: global_var alignment?
51+
len = prod(shape)
5052
T_global = LLVM.ArrayType(T_result, len)
5153
global_var = GlobalVariable(mod, T_global, string(global_name), AS.Constant)
52-
linkage!(global_var, LLVM.API.LLVMExternalLinkage) # NOTE: external linkage is the default
54+
linkage!(global_var, LLVM.API.LLVMWeakAnyLinkage) # merge, but make sure symbols aren't discarded
5355
extinit!(global_var, true)
54-
# TODO: global_var alignment?
56+
# XXX: if we don't extinit, LLVM can inline the constant memory if it's predefined.
57+
# that means we wouldn't be able to re-set it afterwards. do we want that?
58+
59+
# initialize the constant memory
60+
if init !== nothing
61+
arr = reshape([init...], shape)
62+
if isnothing(arr)
63+
GPUCompiler.@safe_error "calling kernel containing garbage collected constant memory"
64+
end
65+
66+
flattened_arr = reduce(vcat, arr)
67+
typ = eltype(T_global)
68+
69+
# TODO: have a look at how julia converts structs to llvm:
70+
# https://github.com/JuliaLang/julia/blob/80ace52b03d9476f3d3e6ff6da42f04a8df1cf7b/src/cgutils.cpp#L572
71+
# this only seems to emit a type though
72+
init = if isa(typ, LLVM.IntegerType) || isa(typ, LLVM.FloatingPointType)
73+
ConstantArray(flattened_arr, ctx)
74+
elseif isa(typ, LLVM.ArrayType) # a struct with every field of the same type gets optimized to an array
75+
constant_arrays = LLVM.Constant[]
76+
for x in flattened_arr
77+
fields = collect(map(name->getfield(x, name), fieldnames(typeof(x))))
78+
constant_array = ConstantArray(fields, ctx)
79+
push!(constant_arrays, constant_array)
80+
end
81+
ConstantArray(typ, constant_arrays)
82+
elseif isa(typ, LLVM.StructType)
83+
constant_structs = LLVM.Constant[]
84+
for x in flattened_arr
85+
constants = LLVM.Constant[]
86+
for fieldname in fieldnames(typeof(x))
87+
field = getfield(x, fieldname)
88+
if isa(field, Bool)
89+
# NOTE: Bools get compiled to i8 instead of the more "correct" type i1
90+
push!(constants, ConstantInt(LLVM.Int8Type(ctx), field))
91+
elseif isa(field, Integer)
92+
push!(constants, ConstantInt(field, ctx))
93+
elseif isa(field, AbstractFloat)
94+
push!(constants, ConstantFP(field, ctx))
95+
else
96+
GPUCompiler.@safe_error "constant memory does not currently support structs with non-primitive fields ($(typeof(x)).$fieldname::$(typeof(field)))"
97+
end
98+
end
99+
const_struct = ConstantStruct(typ, constants)
100+
push!(constant_structs, const_struct)
101+
end
102+
ConstantArray(typ, constant_structs)
103+
else
104+
# unreachable, but let's be safe and throw a nice error message just in case
105+
GPUCompiler.@safe_error "Could not emit initializer for constant memory of type $typ"
106+
nothing
107+
end
108+
109+
init !== nothing && initializer!(global_var, init)
110+
end
55111

56112
# generate IR
57113
Builder(ctx) do builder

src/memory_constant.jl

+27-82
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
export CuConstantMemory
22

3-
# Map a constant memory name to its array value
4-
const constant_memory_initializer = Dict{Symbol,WeakRef}()
5-
63
"""
74
CuConstantMemory{T,N}(value::Array{T,N})
85
CuConstantMemory{T}(::UndefInitializer, dims::Integer...)
@@ -30,35 +27,44 @@ In cases where the same kernel object gets called mutiple times, and it is desir
3027
the value of a `CuConstantMemory` variable in this kernel between calls, please refer
3128
to [`Base.copyto!(const_mem::CuConstantMemory{T}, value::Array{T}, kernel::HostKernel)`](@ref)
3229
"""
33-
struct CuConstantMemory{T,N} <: AbstractArray{T,N}
34-
name::Symbol
35-
value::Array{T,N}
30+
mutable struct CuConstantMemory{T,N} <: AbstractArray{T,N}
31+
name::String
32+
size::Dims{N}
33+
value::Union{Nothing,Array{T,N}}
34+
35+
function CuConstantMemory(value::Array{T,N}; name::String) where {T,N}
36+
Base.isbitstype(T) || throw(ArgumentError("CuConstantMemory only supports bits types"))
37+
return new{T,N}(GPUCompiler.safe_name("constant_" * name), size(value), deepcopy(value))
38+
end
3639

37-
function CuConstantMemory(value::Array{T,N}) where {T,N}
38-
# TODO: add finalizer that removes the relevant entry from constant_memory_initializer?
40+
function CuConstantMemory(::UndefInitializer, dims::Dims{N}; name::String) where {T,N}
3941
Base.isbitstype(T) || throw(ArgumentError("CuConstantMemory only supports bits types"))
40-
name = gensym("constant_memory")
41-
name = GPUCompiler.safe_name(string(name))
42-
name = Symbol(name)
43-
val = deepcopy(value)
44-
constant_memory_initializer[name] = WeakRef(val)
45-
return new{T,N}(name, val)
42+
return new{T,N}(GPUCompiler.safe_name("constant_" * name), dims, nothing)
4643
end
4744
end
4845

49-
CuConstantMemory{T}(::UndefInitializer, dims::Integer...) where {T} =
50-
CuConstantMemory(Array{T}(undef, dims))
51-
CuConstantMemory{T}(::UndefInitializer, dims::Dims{N}) where {T,N} =
52-
CuConstantMemory(Array{T,N}(undef, dims))
46+
CuConstantMemory{T}(::UndefInitializer, dims::Integer...; kwargs...) where {T} =
47+
CuConstantMemory(Array{T}(undef, dims); kwargs...)
48+
CuConstantMemory{T}(::UndefInitializer, dims::Dims{N}; kwargs...) where {T,N} =
49+
CuConstantMemory{T,N}(undef, dims; kwargs...)
5350

54-
Base.size(A::CuConstantMemory) = size(A.value)
51+
Base.size(A::CuConstantMemory) = A.size
5552

5653
Base.getindex(A::CuConstantMemory, i::Integer) = Base.getindex(A.value, i)
5754
Base.setindex!(A::CuConstantMemory, v, i::Integer) = Base.setindex!(A.value, v, i)
5855
Base.IndexStyle(::Type{<:CuConstantMemory}) = Base.IndexLinear()
5956

60-
Adapt.adapt_storage(::Adaptor, A::CuConstantMemory{T,N}) where {T,N} =
61-
CuDeviceConstantMemory{T,N,A.name,size(A.value)}()
57+
function Adapt.adapt_storage(::Adaptor, A::CuConstantMemory{T,N}) where {T,N}
58+
# convert the values to the type domain
59+
# XXX: this is tough on the compiler when dealing with large initializers.
60+
typevals = if A.value !== nothing
61+
Tuple(reshape(A.value, prod(A.size)))
62+
else
63+
nothing
64+
end
65+
66+
CuDeviceConstantMemory{T,N,Symbol(A.name),A.size,typevals}()
67+
end
6268

6369

6470
"""
@@ -74,64 +80,3 @@ function Base.copyto!(const_mem::CuConstantMemory{T}, value::Array{T}, kernel::H
7480
global_array = CuGlobalArray{T}(kernel.mod, string(const_mem.name), length(const_mem))
7581
copyto!(global_array, value)
7682
end
77-
78-
79-
function emit_constant_memory_initializer!(mod::LLVM.Module)
80-
for global_var in globals(mod)
81-
T_global = llvmtype(global_var)
82-
if addrspace(T_global) == AS.Constant
83-
constant_memory_name = Symbol(LLVM.name(global_var))
84-
if !haskey(constant_memory_initializer, constant_memory_name)
85-
continue # non user defined constant memory, most likely from the CUDA runtime
86-
end
87-
88-
arr = constant_memory_initializer[constant_memory_name].value
89-
@assert !isnothing(arr) "calling kernel containing garbage collected constant memory"
90-
91-
flattened_arr = reduce(vcat, arr)
92-
ctx = LLVM.context(mod)
93-
typ = eltype(eltype(T_global))
94-
95-
# TODO: have a look at how julia converts structs to llvm:
96-
# https://github.com/JuliaLang/julia/blob/80ace52b03d9476f3d3e6ff6da42f04a8df1cf7b/src/cgutils.cpp#L572
97-
# this only seems to emit a type though
98-
if isa(typ, LLVM.IntegerType) || isa(typ, LLVM.FloatingPointType)
99-
init = ConstantArray(flattened_arr, ctx)
100-
elseif isa(typ, LLVM.ArrayType) # a struct with every field of the same type gets optimized to an array
101-
constant_arrays = LLVM.Constant[]
102-
for x in flattened_arr
103-
fields = collect(map(name->getfield(x, name), fieldnames(typeof(x))))
104-
constant_array = ConstantArray(fields, ctx)
105-
push!(constant_arrays, constant_array)
106-
end
107-
init = ConstantArray(typ, constant_arrays)
108-
elseif isa(typ, LLVM.StructType)
109-
constant_structs = LLVM.Constant[]
110-
for x in flattened_arr
111-
constants = LLVM.Constant[]
112-
for fieldname in fieldnames(typeof(x))
113-
field = getfield(x, fieldname)
114-
if isa(field, Bool)
115-
# NOTE: Bools get compiled to i8 instead of the more "correct" type i1
116-
push!(constants, ConstantInt(LLVM.Int8Type(ctx), field))
117-
elseif isa(field, Integer)
118-
push!(constants, ConstantInt(field, ctx))
119-
elseif isa(field, AbstractFloat)
120-
push!(constants, ConstantFP(field, ctx))
121-
else
122-
throw(error("constant memory does not currently support structs with non-primitive fields ($(typeof(x)).$fieldname::$(typeof(field)))"))
123-
end
124-
end
125-
const_struct = ConstantStruct(typ, constants)
126-
push!(constant_structs, const_struct)
127-
end
128-
init = ConstantArray(typ, constant_structs)
129-
else
130-
# unreachable, but let's be safe and throw a nice error message just in case
131-
throw(error("could not emit initializer for constant memory of type $typ"))
132-
end
133-
134-
initializer!(global_var, init)
135-
end
136-
end
137-
end

0 commit comments

Comments
 (0)