1
1
export CuConstantMemory
2
2
3
- # Map a constant memory name to its array value
4
- const constant_memory_initializer = Dict {Symbol,WeakRef} ()
5
-
6
3
"""
7
4
CuConstantMemory{T,N}(value::Array{T,N})
8
5
CuConstantMemory{T}(::UndefInitializer, dims::Integer...)
@@ -30,35 +27,44 @@ In cases where the same kernel object gets called mutiple times, and it is desir
30
27
the value of a `CuConstantMemory` variable in this kernel between calls, please refer
31
28
to [`Base.copyto!(const_mem::CuConstantMemory{T}, value::Array{T}, kernel::HostKernel)`](@ref)
32
29
"""
33
- struct CuConstantMemory{T,N} <: AbstractArray{T,N}
34
- name:: Symbol
35
- value:: Array{T,N}
30
+ mutable struct CuConstantMemory{T,N} <: AbstractArray{T,N}
31
+ name:: String
32
+ size:: Dims{N}
33
+ value:: Union{Nothing,Array{T,N}}
34
+
35
+ function CuConstantMemory (value:: Array{T,N} ; name:: String ) where {T,N}
36
+ Base. isbitstype (T) || throw (ArgumentError (" CuConstantMemory only supports bits types" ))
37
+ return new {T,N} (GPUCompiler. safe_name (" constant_" * name), size (value), deepcopy (value))
38
+ end
36
39
37
- function CuConstantMemory (value:: Array{T,N} ) where {T,N}
38
- # TODO : add finalizer that removes the relevant entry from constant_memory_initializer?
40
+ function CuConstantMemory (:: UndefInitializer , dims:: Dims{N} ; name:: String ) where {T,N}
39
41
Base. isbitstype (T) || throw (ArgumentError (" CuConstantMemory only supports bits types" ))
40
- name = gensym (" constant_memory" )
41
- name = GPUCompiler. safe_name (string (name))
42
- name = Symbol (name)
43
- val = deepcopy (value)
44
- constant_memory_initializer[name] = WeakRef (val)
45
- return new {T,N} (name, val)
42
+ return new {T,N} (GPUCompiler. safe_name (" constant_" * name), dims, nothing )
46
43
end
47
44
end
48
45
49
- CuConstantMemory {T} (:: UndefInitializer , dims:: Integer... ) where {T} =
50
- CuConstantMemory (Array {T} (undef, dims))
51
- CuConstantMemory {T} (:: UndefInitializer , dims:: Dims{N} ) where {T,N} =
52
- CuConstantMemory ( Array {T,N} (undef, dims) )
46
+ CuConstantMemory {T} (:: UndefInitializer , dims:: Integer... ; kwargs ... ) where {T} =
47
+ CuConstantMemory (Array {T} (undef, dims); kwargs ... )
48
+ CuConstantMemory {T} (:: UndefInitializer , dims:: Dims{N} ; kwargs ... ) where {T,N} =
49
+ CuConstantMemory {T,N} (undef, dims; kwargs ... )
53
50
54
- Base. size (A:: CuConstantMemory ) = size (A . value)
51
+ Base. size (A:: CuConstantMemory ) = A . size
55
52
56
53
Base. getindex (A:: CuConstantMemory , i:: Integer ) = Base. getindex (A. value, i)
57
54
Base. setindex! (A:: CuConstantMemory , v, i:: Integer ) = Base. setindex! (A. value, v, i)
58
55
Base. IndexStyle (:: Type{<:CuConstantMemory} ) = Base. IndexLinear ()
59
56
60
- Adapt. adapt_storage (:: Adaptor , A:: CuConstantMemory{T,N} ) where {T,N} =
61
- CuDeviceConstantMemory {T,N,A.name,size(A.value)} ()
57
+ function Adapt. adapt_storage (:: Adaptor , A:: CuConstantMemory{T,N} ) where {T,N}
58
+ # convert the values to the type domain
59
+ # XXX : this is tough on the compiler when dealing with large initializers.
60
+ typevals = if A. value != = nothing
61
+ Tuple (reshape (A. value, prod (A. size)))
62
+ else
63
+ nothing
64
+ end
65
+
66
+ CuDeviceConstantMemory {T,N,Symbol(A.name),A.size,typevals} ()
67
+ end
62
68
63
69
64
70
"""
@@ -74,64 +80,3 @@ function Base.copyto!(const_mem::CuConstantMemory{T}, value::Array{T}, kernel::H
74
80
global_array = CuGlobalArray {T} (kernel. mod, string (const_mem. name), length (const_mem))
75
81
copyto! (global_array, value)
76
82
end
77
-
78
-
79
- function emit_constant_memory_initializer! (mod:: LLVM.Module )
80
- for global_var in globals (mod)
81
- T_global = llvmtype (global_var)
82
- if addrspace (T_global) == AS. Constant
83
- constant_memory_name = Symbol (LLVM. name (global_var))
84
- if ! haskey (constant_memory_initializer, constant_memory_name)
85
- continue # non user defined constant memory, most likely from the CUDA runtime
86
- end
87
-
88
- arr = constant_memory_initializer[constant_memory_name]. value
89
- @assert ! isnothing (arr) " calling kernel containing garbage collected constant memory"
90
-
91
- flattened_arr = reduce (vcat, arr)
92
- ctx = LLVM. context (mod)
93
- typ = eltype (eltype (T_global))
94
-
95
- # TODO : have a look at how julia converts structs to llvm:
96
- # https://github.com/JuliaLang/julia/blob/80ace52b03d9476f3d3e6ff6da42f04a8df1cf7b/src/cgutils.cpp#L572
97
- # this only seems to emit a type though
98
- if isa (typ, LLVM. IntegerType) || isa (typ, LLVM. FloatingPointType)
99
- init = ConstantArray (flattened_arr, ctx)
100
- elseif isa (typ, LLVM. ArrayType) # a struct with every field of the same type gets optimized to an array
101
- constant_arrays = LLVM. Constant[]
102
- for x in flattened_arr
103
- fields = collect (map (name-> getfield (x, name), fieldnames (typeof (x))))
104
- constant_array = ConstantArray (fields, ctx)
105
- push! (constant_arrays, constant_array)
106
- end
107
- init = ConstantArray (typ, constant_arrays)
108
- elseif isa (typ, LLVM. StructType)
109
- constant_structs = LLVM. Constant[]
110
- for x in flattened_arr
111
- constants = LLVM. Constant[]
112
- for fieldname in fieldnames (typeof (x))
113
- field = getfield (x, fieldname)
114
- if isa (field, Bool)
115
- # NOTE: Bools get compiled to i8 instead of the more "correct" type i1
116
- push! (constants, ConstantInt (LLVM. Int8Type (ctx), field))
117
- elseif isa (field, Integer)
118
- push! (constants, ConstantInt (field, ctx))
119
- elseif isa (field, AbstractFloat)
120
- push! (constants, ConstantFP (field, ctx))
121
- else
122
- throw (error (" constant memory does not currently support structs with non-primitive fields ($(typeof (x)) .$fieldname ::$(typeof (field)) )" ))
123
- end
124
- end
125
- const_struct = ConstantStruct (typ, constants)
126
- push! (constant_structs, const_struct)
127
- end
128
- init = ConstantArray (typ, constant_structs)
129
- else
130
- # unreachable, but let's be safe and throw a nice error message just in case
131
- throw (error (" could not emit initializer for constant memory of type $typ " ))
132
- end
133
-
134
- initializer! (global_var, init)
135
- end
136
- end
137
- end
0 commit comments