
Constant memory support #552

Draft · wants to merge 6 commits into master

Conversation

@S-D-R (Contributor) commented Nov 17, 2020

This PR adds support for constant memory. A non-exhaustive list of stuff to think about and/or work on:

  • Should constant memory be allowed to be defined in the local scope? Currently this is supported (cf. the tests), but it requires that CuConstantMemory is an isbits type, which in turn makes initialisation more convoluted, since an isbits type cannot store an array field. Alternatively, some compiler work could be done to allow non-isbits types as CUDA kernel arguments.
  • Initialisation could be improved, if not from a design perspective, then at least by using something like a WeakKeyDict so that CuConstantMemory objects that should have been GC'd are not needlessly kept alive
  • Look at the generated LLVM/PTX code for more complex use cases, making sure the compiler is generating the right code
  • Add some performance benchmarks
  • More tests, as always
  • Documentation once the final API has been decided

@maleadt added the labels "cuda kernels" (Stuff about writing CUDA kernels) and "enhancement" (New feature or request) on Nov 17, 2020
@maleadt (Member) commented Nov 17, 2020

Should constant memory be allowed to be defined in the local scope? Currently this is supported (cf. the tests), but it requires that CuConstantMemory is an isbits type, which in turn makes initialisation more convoluted, since an isbits type cannot store an array field. Alternatively, some compiler work could be done to allow non-isbits types as CUDA kernel arguments.

One alternative is to create a device-side isbits type, like CuDeviceArray, and use Adapt/cudaconvert to automatically convert from the host to device type at the kernel launch point.
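
A minimal sketch of what such a host/device pair could look like; the device-side type and its fields are hypothetical, and the generic Adapt.adapt_structure hook (which cudaconvert goes through) is used so the conversion happens at the launch point:

using Adapt

struct CuConstantMemory{T,N}             # host-side handle, holds a plain Array
    value::Array{T,N}
    name::Symbol
end

struct CuDeviceConstantMemory{T,N,Name}  # device-side counterpart, isbits
    dims::Dims{N}
end

# swap the host type for the device type at kernel launch, moving the name into
# the type domain so the struct stays isbits
Adapt.adapt_structure(to, c::CuConstantMemory{T,N}) where {T,N} =
    CuDeviceConstantMemory{T,N,c.name}(size(c.value))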

@codecov bot commented Nov 17, 2020

Codecov Report

Attention: Patch coverage is 65.93407% with 31 lines in your changes missing coverage. Please review.

Project coverage is 79.56%. Comparing base (4eb99b9) to head (cec71b0).
Report is 1966 commits behind head on master.

Files with missing lines        Patch %   Lines
src/memory_global.jl              0.00%   18 Missing ⚠️
lib/cudadrv/module/global.jl     57.89%    8 Missing ⚠️
src/memory_constant.jl           87.50%    3 Missing ⚠️
src/compiler/execution.jl        93.10%    2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #552      +/-   ##
==========================================
- Coverage   79.69%   79.56%   -0.13%     
==========================================
  Files         122      124       +2     
  Lines        7356     7429      +73     
==========================================
+ Hits         5862     5911      +49     
- Misses       1494     1518      +24     


@maleadt (Member) commented Nov 24, 2020

One alternative is to create a device-side isbits type, like CuDeviceArray, and use Adapt/cudaconvert to automatically convert from the host to device type at the kernel launch point.

I realized that only works for arguments; the conversion machinery doesn't get called for globals or captured variables (#67).

Concerning initialization, we definitely don't want to do this on every launch. Modules are cached, so the values should persist (if we don't consider global device memory for now). That means we only need to initialize after compiling a module, or when re-initializing. But to do so we need both mappings (module->gv for the initial compilation, gv->module for re-initialization).

Let's maybe keep it simpler for now, requiring either a literal value in the constant constructor (in which case you can inline it when creating the GV, so that you don't even need an external initializer), or that the user does the memcpyToSymbol to the module themselves (ideally made a little more user-friendly by going through the CuConstantMemory object). For that, you could use the cufunction 'trick' we do when we need to introspect the CuFunction (it contains the parent CuModule). Maybe we should improve the ergonomics of that though, e.g. using @cuda delayed=true ... or something.
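
For reference, a rough sketch of that cufunction 'trick'; the field accesses and the final copy step are assumptions, not a settled API:

using CUDA

function my_kernel(x)
    return nothing
end

kernel = cufunction(my_kernel, Tuple{Int})   # compile, but don't launch
cumod = kernel.fun.mod                       # the CuFunction carries its parent CuModule
# ...memcpyToSymbol-style copy of the host values into cumod's constant global goes here...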

@S-D-R (Contributor, Author) commented Nov 24, 2020

Another way of doing initialisation I've been thinking about, which retains all of the current functionality while avoiding some of the overhead:

  1. Create an array A of symbols which keeps track of the names of all constant memory that needs to be initialised during the next call to initialize_constant_memory.
  2. Change constant_memory_init_dict to be a Dict{Symbol,Array}, which maps a constant memory name to its value
  3. In GPUCompiler.finish_module! loop through all globals in the LLVM.Module. If any constant memory is found (the global is in address space 4), add its name to A.
  4. In initialize_constant_memory, simply loop through A and initialise each entry by looking up its value in constant_memory_init_dict by name. Clear A once we're done initialising.

The only issue I see with this approach is that there's not really a way of freeing the memory inside constant_memory_init_dict. There's an argument to be made that this doesn't really matter for constant memory, since there can only be 64 KB of constant memory per CuModule. So even if you create 1000 CuConstantMemory objects of 64 KB each, that's still "only" 64 MB of wasted memory. This does however become a problem if we want to implement __device__ memory in a similar way.
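
A minimal sketch of steps 1-4, with hypothetical helper names; the address-space-4 check stands in for however constant globals end up being tagged, and the LLVM.jl calls reflect the API of the time:

using CUDA, LLVM

const pending_constant_memory = Symbol[]                  # the array A from step 1
const constant_memory_init_dict = Dict{Symbol,Array}()    # step 2: name => initial value

# step 3: called from GPUCompiler.finish_module!
function record_constant_globals!(mod::LLVM.Module)
    for gv in LLVM.globals(mod)
        if LLVM.addrspace(llvmtype(gv)) == 4              # constant address space
            push!(pending_constant_memory, Symbol(LLVM.name(gv)))
        end
    end
end

# step 4: called once the CuModule has been loaded
function initialize_constant_memory(cumod::CuModule)
    for name in pending_constant_memory
        haskey(constant_memory_init_dict, name) || continue
        # copy_to_symbol! is a placeholder for the actual memcpy to the module's global
        copy_to_symbol!(cumod, name, constant_memory_init_dict[name])
    end
    empty!(pending_constant_memory)
end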

@maleadt (Member) commented Nov 24, 2020

Global state is annoying (esp. in the presence of multiple threads, tasks, devices), so I'd rather not introduce it if we can avoid it. I implemented the more convenient @cuda delayed=... argument in #569, so maybe try a more imperative, CUDA C-like API first? What about:

c = CuConstantMemory{Int}([42])
# const initializer put in the LLVM IR

function kernel(...)
    c[...]
end

@cuda kernel(...)

c = CuConstantMemory{Int}(undef, (1,))
# external initializer

function kernel(...)
    c[...]
end

kernel_obj = @cuda delayed=true kernel(...)
constant_memory(kernel_obj)[c] = [42] # or some other function, doesn't really matter
kernel_obj(...)

@maleadt force-pushed the master branch 2 times, most recently from 91db6b0 to 06fe10b on January 8, 2021
@maleadt (Member) commented Jan 25, 2021

Squashed and rebased.

arr = constant_memory_initializer[constant_memory_name].value
@assert !isnothing(arr) "calling kernel containing garbage collected constant memory"

flattened_arr = reduce(vcat, arr)

Review comment (Member) on the hunk above:

julia> reduce(vcat, [1])
1

...

ERROR: MethodError: no method matching LLVM.ConstantArray(::Int32, ::LLVM.Context)
Closest candidates are:
  LLVM.ConstantArray(::AbstractArray{Float64, N}, ::LLVM.Context) where N at /home/tim/Julia/pkg/LLVM/src/core/value/constant.jl:163
  LLVM.ConstantArray(::AbstractArray{Float32, N}, ::LLVM.Context) where N at /home/tim/Julia/pkg/LLVM/src/core/value/constant.jl:161
  LLVM.ConstantArray(::AbstractArray{Float16, N}, ::LLVM.Context) where N at /home/tim/Julia/pkg/LLVM/src/core/value/constant.jl:159
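
A possible fix, assuming the intent is simply to flatten the array: vec always returns a Vector, even for a single-element input, so LLVM.ConstantArray keeps receiving an AbstractArray:

flattened_arr = vec(arr)   # vec([1]) == [1], whereas reduce(vcat, [1]) == 1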

@maleadt (Member) commented Jan 25, 2021

I pushed a WIP commit illustrating what I meant by eagerly initializing, to avoid initializing as part of the compiler: this lets LLVM optimize given the actual constant values. There's an issue though: Julia discards the linkage when importing these variables via llvmcall:

function main()
    Base.llvmcall(
        ("""@constant_memory = addrspace(4) externally_initialized global [1 x i32] [i32 42]
            define void @entry() {
                ret void
            }
         """, "entry"), Nothing, Tuple{})
end

main()


##

using InteractiveUtils
@code_llvm dump_module=true main()


##

using LLVM

# get the method instance
world = Base.get_world_counter()
meth = which(main, Tuple{})
sig = Base.signature_type(main, Tuple{})::Type
(ti, env) = ccall(:jl_type_intersection_with_env, Any,
                    (Any, Any), sig, meth.sig)::Core.SimpleVector
meth = Base.func_for_method_checked(meth, ti, env)
method_instance = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
                (Any, Any, Any, UInt), meth, ti, env, world)

# set-up the compiler interface
params = Base.CodegenParams()

# generate IR
native_code = ccall(:jl_create_native, Ptr{Cvoid},
                    (Vector{Core.MethodInstance}, Base.CodegenParams, Cint),
                    [method_instance], params, #=extern policy=# 1)
@assert native_code != C_NULL
llvm_mod_ref = ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
                        (Ptr{Cvoid},), native_code)
@assert llvm_mod_ref != C_NULL
llvm_mod = LLVM.Module(llvm_mod_ref)
println(llvm_mod)

The (implicitly) external symbol becomes internal in the second case. Because of that, the problematic SRA optimization kicks in.

EDIT: ah, https://github.com/JuliaLang/julia/blob/dbaca8ba16d406004c53cb211b9eaf8028f6b6be/src/aotcompile.cpp#L409

Anyway, once this is solved the approach is a bit better, I think. We could decide to make non-undef ConstantArrays not externally visible then, which would make the SRA transformation legal. I'm not sure whether we should keep them externally initialized: if not, it would also be legal to inline the constants in the function, but in that case the constant memory hardware wouldn't be used anymore. But maybe that's an improvement when LLVM can deduce the indices anyway.
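
For reference, a hedged sketch of toggling those two properties through LLVM.jl, assuming the global-by-name lookup and the raw LLVMSetExternallyInitialized binding are available:

using LLVM

gv = LLVM.globals(llvm_mod)["constant_memory"]
# not externally visible -> the SRA transformation becomes legal
LLVM.linkage!(gv, LLVM.API.LLVMInternalLinkage)
# dropping externally_initialized would additionally let LLVM inline the constants,
# at the cost of no longer using the constant memory hardware
LLVM.API.LLVMSetExternallyInitialized(gv, false)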

Finally, we could get rid of the initializer map by putting the values in the device counterpart's type (as a Tuple(reshape(vals, prod(size(vals))))), but that's just very taxing on the compiler. Better to keep it a weakref like that.
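
Illustratively, that type-domain variant would look something like this (the device type in the comment is hypothetical):

vals = Int32[1 2; 3 4]
init = Tuple(reshape(vals, prod(size(vals))))   # (1, 3, 2, 4): isbits, so valid as a type parameter
# CuDeviceConstantMemory{Int32, 2, :my_constant, init}   # every distinct value specializes the type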


function CuConstantMemory(value::Array{T,N}) where {T,N}
    Base.isbitstype(T) || throw(ArgumentError("CuConstantMemory only supports bits types"))
    name = gensym("constant_memory")

Review comment (Member) on the hunk above:

A problem with this is that every invocation of a function that launches a kernel with constant memory will result in a recompilation (because the name is tied to the instance).

@S-D-R (Contributor, Author) commented Jan 25, 2021

Anyway, once this is solved the approach is a bit better, I think. We could decide to make non-undef ConstantArrays not externally visible then, which would make the SRA transformation legal. I'm not sure whether we should keep them externally initialized: if not, it would also be legal to inline the constants in the function, but in that case the constant memory hardware wouldn't be used anymore. But maybe that's an improvement when LLVM can deduce the indices anyway.

Based on empirical evidence from playing around with CUDA C code, nvcc doesn't seem to inline constant (or, for that matter, global) memory, even in very simple cases (see also this SO question). This is most likely due to the difficulty of statically determining whether memcpy will ever be called on a certain global variable in C. As you mention, inlining will probably lead to better performance in most use cases, but some users might not like it when they explicitly ask for constant memory. I suppose a solution to this would be adding an allow_inlining=true constructor argument to CuConstantMemory.

@maleadt (Member) commented Jan 25, 2021

Did some more hacking, but I'm not entirely happy with the result yet. Previously, you gensymmed a name for every ConstantArray, which would result in a new kernel getting compiled. That's not an option, as it'll kill launch performance. But the alternative, not generating a name but looking at the contents, could result in two identical-looking CuConstantMemory objects, so you can't discern those (e.g. for a memcpy afterwards). So I guess the only solution for this kind of lexical identification is to pass a name argument, possibly auto-generated by a @cuConstantMemory macro? Not sure I like this.
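
A rough sketch of what such a macro could do, deriving a stable name from the call site; the two-argument CuConstantMemory(value, name) constructor is an assumption:

macro cuConstantMemory(value)
    # same source location -> same name, so repeated calls don't trigger recompilation
    name = QuoteNode(Symbol("constant_memory_", __source__.file, "_", __source__.line))
    :(CuConstantMemory($(esc(value)), $name))
end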

I also tried getting rid of the global map in favor of passing the initializer as Val, which is tough on the compiler but fixes an issue where (after switching to a name param) re-using a name would result in collisions in the global map.

So I might just revert all that and go back to your design, but I wanted to try some things out first :-)
