Skip to content

NNlib.gather triggers InvalidIRError when doing Forward over Reverse for hessian computation on GPU #547

Open
@vpuri3

Description

vedantpu@eagle hess]:cat scat.jl 
#
using CUDA, NNlib
using Zygote, ForwardDiff

CUDA.allowscalar(false) # turn off scalar indexing of GPU arrays for this script

#==========================#
function hess_gather( # Hessian of sum(abs2, NNlib.gather(x, i)) w.r.t. x (ForwardDiff over Zygote)
    x::AbstractMatrix, # point at which the Hessian is evaluated
    i::AbstractVector{<:Integer}; # indices handed to NNlib.gather
    ifgpu = false, # true => convert x and i with `cu` before differentiating
)
    if ifgpu
        x = x  |> cu
        i = i .|> Int32 |> cu # indices are cast element-wise to Int32, then passed to `cu`
    end

    function loss(x) # scalar objective; closes over the (possibly converted) `i`
        y = NNlib.gather(x, i)

        sum(abs2, y)
    end

    g(x) = Zygote.gradient(loss, x)[1] # reverse-mode gradient of `loss` w.r.t. x
    H(x) = ForwardDiff.jacobian(g, x) # forward-mode Jacobian of `g`, i.e. the Hessian of `loss`

    H(x) # value returned to the caller
end
#==========================#

# Problem sizes: x is O×E, and K indices are drawn from 1:E.
E = 3
O = 5
K = 10
x = rand(O, E)
i = rand(1:E, K)

# Exercise the GPU path (ifgpu = true).
hess_gather(x, i; ifgpu = true)
#
ERROR: LoadError: InvalidIRError: compiling MethodInstance for NNlibCUDAExt.scatter_kernel!(::typeof(+), ::CuDeviceMatrix{…}, ::CuDeviceMatrix{…}, ::CuDeviceVector{…}, 
::Int64, ::Int64, ::Tuple{…}) resulted in invalid LLVM IR                                                                                                               
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)                                                                                            
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)                                                                                           
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)                                                                                            
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)                                                                                       
Reason: unsupported dynamic function invocation (call to atomic_cas!)                                                                                                   
Stacktrace:                                                                                                                                                             
 [1] atomic_op!                                                                     
   @ ~/.julia/packages/CUDA/nbRJk/src/device/intrinsics/atomics.jl:228                                                                                                  
 [2] atomic_arrayset                                                                                                                                                    
   @ ~/.julia/packages/CUDA/nbRJk/src/device/intrinsics/atomics.jl:468                                                                                                  
 [3] atomic_arrayset                                                                                                                                                    
   @ ~/.julia/packages/CUDA/nbRJk/src/device/intrinsics/atomics.jl:440
 [4] macro expansion                                                                
   @ ~/.julia/packages/CUDA/nbRJk/src/device/intrinsics/atomics.jl:435                                                                                                  
 [5] scatter_kernel!
   @ ~/.julia/packages/NNlib/5iRSB/ext/NNlibCUDAExt/scatter.jl:28         
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl
Stacktrace:
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, args::LLVM.Module)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/validation.jl:147
  [2] macro expansion
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:440 [inlined]
  [3] macro expansion
    @ GPUCompiler ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
  [4] macro expansion
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:439 [inlined]
  [5] emit_llvm(job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, only_entry::Bool, validate::Bool)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/utils.jl:92
  [6] emit_llvm
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/utils.jl:86 [inlined]
  [7] 
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:129
  [8] codegen
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:110 [inlined]
  [9] compile(target::Symbol, job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, strip::Bool, validate::Bool, only_entry::Bool)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:106
 [10] compile
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:98 [inlined]
 [11] #1042
    @ GPUCompiler ~/.julia/packages/CUDA/nbRJk/src/compiler/compilation.jl:166 [inlined]
 [12] JuliaContext(f::CUDA.var"#1042#1045"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:47
 [13] compile(job::GPUCompiler.CompilerJob)
    @ CUDA ~/.julia/packages/CUDA/nbRJk/src/compiler/compilation.jl:165
 [14] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/execution.jl:125
 [15] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/execution.jl:103
 [16] macro expansion
    @ CUDA ~/.julia/packages/CUDA/nbRJk/src/compiler/execution.jl:323 [inlined]
 [17] macro expansion
    @ CUDA ./lock.jl:267 [inlined]
 [18] cufunction(f::typeof(NNlibCUDAExt.scatter_kernel!), tt::Type{Tuple{…}}; kwargs::@Kwargs{})
    @ CUDA ~/.julia/packages/CUDA/nbRJk/src/compiler/execution.jl:318
 [19] cufunction
    @ NNlibCUDAExt ~/.julia/packages/CUDA/nbRJk/src/compiler/execution.jl:315 [inlined]
 [20] macro expansion
    @ NNlibCUDAExt ~/.julia/packages/CUDA/nbRJk/src/compiler/execution.jl:104 [inlined]
 [21] scatter!(op::typeof(+), dst::CuArray{…}, src::CuArray{…}, idx::CuArray{…})
    @ NNlibCUDAExt ~/.julia/packages/NNlib/5iRSB/ext/NNlibCUDAExt/scatter.jl:58
 [22] ∇gather_src
    @ NNlib ~/.julia/packages/NNlib/5iRSB/src/gather.jl:131 [inlined]
 [23] gather!_pullback
    @ NNlib ~/.julia/packages/NNlib/5iRSB/src/gather.jl:136 [inlined]
 [24] ZBack
    @ Zygote ~/.julia/dev/Zygote/src/compiler/chainrules.jl:211 [inlined]
 [25] gather
    @ Zygote ~/.julia/packages/NNlib/5iRSB/src/gather.jl:46 [inlined]
 [26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CuArray{ForwardDiff.Dual{…}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [27] loss
    @ Zygote ~/.julia/dev/GeometryLearning.jl/hess/scat.jl:19 [inlined]
...

With `ifgpu = false` (so `x` and `i` are not converted with `cu`), the same script runs without error:

julia> include("scat.jl")
15×15 Matrix{Float64}:
 6.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  6.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  6.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  6.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  6.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  8.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  8.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  8.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  8.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  8.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  6.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  6.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  6.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  6.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  6.0

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions