NNlib.gather triggers InvalidIRError when doing forward-over-reverse Hessian computation on GPU #547
Description
[vedantpu@eagle hess]$ cat scat.jl
#
using CUDA, NNlib
using Zygote, ForwardDiff
CUDA.allowscalar(false)
#==========================#
function hess_gather(
    x::AbstractMatrix,
    i::AbstractVector{<:Integer};
    ifgpu = false,
)
    if ifgpu
        x = x |> cu
        i = i .|> Int32 |> cu
    end

    function loss(x)
        y = NNlib.gather(x, i)
        sum(abs2, y)
    end

    g(x) = Zygote.gradient(loss, x)[1] # reverse mode (Zygote)
    H(x) = ForwardDiff.jacobian(g, x)  # forward over reverse

    H(x)
end
#==========================#
E, O, K = 3, 5, 10 # number of source columns, rows, gather indices
x = rand(O, E)
i = rand(1:E, K)   # columns of x to gather
hess_gather(x, i; ifgpu = true)
#
ERROR: LoadError: InvalidIRError: compiling MethodInstance for NNlibCUDAExt.scatter_kernel!(::typeof(+), ::CuDeviceMatrix{…}, ::CuDeviceMatrix{…}, ::CuDeviceVector{…}, ::Int64, ::Int64, ::Tuple{…}) resulted in invalid LLVM IR
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Reason: unsupported dynamic function invocation (call to atomic_cas!)
Stacktrace:
[1] atomic_op!
@ ~/.julia/packages/CUDA/nbRJk/src/device/intrinsics/atomics.jl:228
[2] atomic_arrayset
@ ~/.julia/packages/CUDA/nbRJk/src/device/intrinsics/atomics.jl:468
[3] atomic_arrayset
@ ~/.julia/packages/CUDA/nbRJk/src/device/intrinsics/atomics.jl:440
[4] macro expansion
@ ~/.julia/packages/CUDA/nbRJk/src/device/intrinsics/atomics.jl:435
[5] scatter_kernel!
@ ~/.julia/packages/NNlib/5iRSB/ext/NNlibCUDAExt/scatter.jl:28
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl
Stacktrace:
[1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, args::LLVM.Module)
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/validation.jl:147
[2] macro expansion
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:440 [inlined]
[3] macro expansion
@ GPUCompiler ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
[4] macro expansion
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:439 [inlined]
[5] emit_llvm(job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, only_entry::Bool, validate::Bool)
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/utils.jl:92
[6] emit_llvm
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/utils.jl:86 [inlined]
[7]
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:129
[8] codegen
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:110 [inlined]
[9] compile(target::Symbol, job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, strip::Bool, validate::Bool, only_entry::Bool)
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:106
[10] compile
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:98 [inlined]
[11] #1042
@ GPUCompiler ~/.julia/packages/CUDA/nbRJk/src/compiler/compilation.jl:166 [inlined]
[12] JuliaContext(f::CUDA.var"#1042#1045"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}})
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:47
[13] compile(job::GPUCompiler.CompilerJob)
@ CUDA ~/.julia/packages/CUDA/nbRJk/src/compiler/compilation.jl:165
[14] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/execution.jl:125
[15] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
@ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/execution.jl:103
[16] macro expansion
@ CUDA ~/.julia/packages/CUDA/nbRJk/src/compiler/execution.jl:323 [inlined]
[17] macro expansion
@ CUDA ./lock.jl:267 [inlined]
[18] cufunction(f::typeof(NNlibCUDAExt.scatter_kernel!), tt::Type{Tuple{…}}; kwargs::@Kwargs{})
@ CUDA ~/.julia/packages/CUDA/nbRJk/src/compiler/execution.jl:318
[19] cufunction
@ NNlibCUDAExt ~/.julia/packages/CUDA/nbRJk/src/compiler/execution.jl:315 [inlined]
[20] macro expansion
@ NNlibCUDAExt ~/.julia/packages/CUDA/nbRJk/src/compiler/execution.jl:104 [inlined]
[21] scatter!(op::typeof(+), dst::CuArray{…}, src::CuArray{…}, idx::CuArray{…})
@ NNlibCUDAExt ~/.julia/packages/NNlib/5iRSB/ext/NNlibCUDAExt/scatter.jl:58
[22] ∇gather_src
@ NNlib ~/.julia/packages/NNlib/5iRSB/src/gather.jl:131 [inlined]
[23] gather!_pullback
@ NNlib ~/.julia/packages/NNlib/5iRSB/src/gather.jl:136 [inlined]
[24] ZBack
@ Zygote ~/.julia/dev/Zygote/src/compiler/chainrules.jl:211 [inlined]
[25] gather
@ Zygote ~/.julia/packages/NNlib/5iRSB/src/gather.jl:46 [inlined]
[26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CuArray{ForwardDiff.Dual{…}, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[27] loss
@ Zygote ~/.julia/dev/GeometryLearning.jl/hess/scat.jl:19 [inlined]
...
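The failing kernel is in the reverse pass: the pullback of NNlib.gather is NNlib.scatter!(+, ...), whose CUDA kernel accumulates with atomics. Under ForwardDiff.jacobian the cotangents carry ForwardDiff.Dual elements, and CUDA.jl's atomics then fall back to a compare-and-swap loop (atomic_cas!) that is only defined for primitive 32/64-bit types, not for Dual structs, hence the dynamic invocation and GC-frame errors above. A minimal sketch that should reach the same kernel directly (the Dual tag and partial count here are arbitrary choices for illustration, not taken from the issue):

using CUDA, NNlib, ForwardDiff

const D = ForwardDiff.Dual{Nothing,Float32,1}  # stand-in Dual element type

dst = CUDA.zeros(D, 5, 3)              # accumulation target, as in ∇gather_src
src = CUDA.fill(one(D), 5, 10)         # cotangent carrying Dual partials
idx = cu(rand(Int32(1):Int32(3), 10))  # which column of dst each src column adds into

NNlib.scatter!(+, dst, src, idx)       # expected to throw the same InvalidIRError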
With ifgpu = false, the same script computes the Hessian correctly on the CPU:
julia> include("scat.jl")
15×15 Matrix{Float64}:
6.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 6.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 6.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 6.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 6.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 8.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 8.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 8.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 8.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 8.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 6.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 6.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 6.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 6.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 6.0
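This is the expected result: with loss(x) = sum(abs2, gather(x, i)), each diagonal entry of the Hessian is 2·c_e, where c_e counts how often column e appears in i (here the three columns were drawn 3, 4, and 3 times out of K = 10).

A possible workaround until the scatter kernel supports Dual element types, assuming the gather indices are fixed: express the gather as multiplication by a 0/1 selection matrix, so the pullback is another matmul and no atomic scatter is involved. This is a sketch, not a confirmed fix (hess_gather_matmul is a hypothetical name), and whether the Dual-eltype matmuls stay on the generic GPU fallback without scalar indexing is untested here:

using CUDA, Zygote, ForwardDiff

function hess_gather_matmul(x::AbstractMatrix, i::AbstractVector{<:Integer})
    x = cu(x)
    # P[e, k] = 1 iff i[k] == e, so that x * P == NNlib.gather(x, i)
    P = cu(Float32[e == ik for e in 1:size(x, 2), ik in i])
    loss(x) = sum(abs2, x * P)
    g(x) = Zygote.gradient(loss, x)[1]
    ForwardDiff.jacobian(g, x)
end

hess_gather_matmul(x, i)  # should match the CPU result above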