Hessian of NN output wrt input (GPU) #1070
Hmm, could we try removing the warning at https://github.com/FluxML/Flux.jl/blob/c91867fd8cc49b19d90c88e0346db674cb757abd/src/utils.jl#L623? I would also try an older Zygote version. Additionally, the MWE might be able to be reduced further, removing much of NeuralPDE and DiffEqFlux and depend mostly on the
We should probably toss an
Would this be a MWE?
using CUDA, Flux, Zygote
CUDA.allowscalar(false)
dim = 2; nn = 10;
chain = Chain(Dense(dim, nn, tanh), Dense(nn, nn, tanh), Dense(nn, 1));
θ, re = Flux.destructure(chain);
initθ_g = cu(θ); # parameter vector in gpu
phi_g = (x,θ) -> re(θ)(cu(x)) # nn function with gpu params
initθ_c = Array(θ); # parameter vector in cpu
phi_c = (x,θ) -> re(θ)(Array(x)) # function with cpu params
ρFn_c(y) = sum(phi_c(y, initθ_c)); # phi_c returns the NN output (cpu)
ρFn_g(y) = sum(phi_g(y, initθ_g)); # phi_g returns the NN output (GPU)
xg = rand(2) |> gpu;
xc = Array(xg);
# @show ρFn_c(xc) # works
# @show ρFn_g(xg) # works
## gradient wrt input
dρ_g(x) = Zygote.gradient(ρFn_g, x)[1];
# @show dρ_g(xg) # works
dρ_c(x) = Zygote.gradient(ρFn_c, x)[1];
# @show dρ_c(xc) # works
## hessian wrt input
d2ρ_c(x) = Zygote.hessian(ρFn_c, x);
# d2ρ_c(xc) # works
d2ρ_g(x) = Zygote.hessian(ρFn_g, x);
# d2ρ_g(xg) # doesn't work (forward over reverse)
# Zygote.hessian_reverse(ρFn_g, xg) # doesn't work (reverse over reverse)
The errors are the same as before.
The error for
julia> hessian(x -> sum(tanh.(x)), [1,2,3.4]) # CPU, ok
3×3 Matrix{Float64}:
-0.6397 0.0 0.0
0.0 -0.136219 0.0
0.0 0.0 -0.0088706
julia> hessian(x -> sum(tanh.(x)), cu([1,2,3.4])) # GPU, same error as above
ERROR: MethodError: no method matching cudnnDataType(::Type{ForwardDiff.Dual{Nothing, Float32, 3}})
julia> gradient(x -> sum(tanh.(x)), [1,2,3.4])
([0.41997434161402614, 0.07065082485316443, 0.004445193185743657],)
julia> gradient(x -> sum(tanh.(x)), cu([1,2,3.4])) # 1st derivative works fine
(Float32[0.4199743, 0.070650816, 0.004445251],)
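For reference, the CPU Hessian above can be cross-checked by hand. Since `sum(tanh.(x))` is a sum of terms that each depend on one coordinate, its Hessian should be diagonal with entries `-2 tanh(x) (1 - tanh(x)^2)`. A minimal sketch of that check, in plain Julia with no AD packages (the names `tanh_second_derivative` and `analytic_hessian` are made up for this illustration):

```julia
# Second derivative of tanh: d²/dx² tanh(x) = -2 tanh(x) (1 - tanh(x)^2).
tanh_second_derivative(x) = -2 * tanh(x) * (1 - tanh(x)^2)

# Hypothetical helper for this check only: closed-form Hessian of
# f(x) = sum(tanh.(x)), which is diagonal because the terms do not interact.
function analytic_hessian(x::AbstractVector)
    n = length(x)
    H = zeros(float(eltype(x)), n, n)
    for i in 1:n
        H[i, i] = tanh_second_derivative(x[i])
    end
    return H
end

H = analytic_hessian([1, 2, 3.4])
```

The diagonal of `H` should agree with the `hessian` output printed above to the displayed precision, which suggests the CPU result is the right target for any GPU workaround.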
(@v1.7) pkg> st
Status `~/.julia/environments/v1.7/Project.toml`
[052768ef] CUDA v3.4.2
[587475ba] Flux v0.12.6
[e88e6eb3] Zygote v0.6.21
Second derivatives Zygote over Zygote should cause it to use a (slower) generic method which I believe is usually 2nd differentiable, but won't work on the GPU. This is also where the two
julia> Zygote.hessian_reverse(x -> sum(tanh.(x)), [1,2,3.4]) # reverse over reverse, CPU
3×3 Matrix{Float64}:
-0.6397 -0.0 -0.0
-0.0 -0.136219 -0.0
-0.0 -0.0 -0.0088706
julia> Zygote.hessian_reverse(x -> sum(tanh.(x)), cu([1,2,3.4]))
ERROR: Mutating arrays is not supported -- called copyto!(::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, _...)
julia> Zygote.hessian_reverse(ρFn_c, xc) # above MWE, reverse over reverse, CPU
ERROR: Can't differentiate foreigncall expression
julia> Zygote.hessian_reverse(ρFn_g, xg) # similar error on GPU
ERROR: Can't differentiate foreigncall expression
I don't understand from the error how broadcasting is going wrong. Maybe it's possible to mess around with overloading
Maybe it's possible to use two different AD packages. I know some people do ReverseDiff over Zygote or the reverse.
julia> Zygote.hessian_reverse(ρFn_c, xc)
ERROR: Can't differentiate foreigncall expression
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] Pullback
@ ./iddict.jl:102 [inlined]
[3] (::typeof(∂(get)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[4] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:68 [inlined]
[5] (::typeof(∂(accum_global)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[6] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:79 [inlined]
[7] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[9] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[10] gradtuple1
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:22 [inlined]
[11] #1640#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[12] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float32}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41 [inlined]
[14] (::typeof(∂(λ)))(Δ::Tuple{Vector{Float32}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[15] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:76 [inlined]
[16] (::typeof(∂(gradient)))(Δ::Tuple{Vector{Float32}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[17] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:87 [inlined]
[18] (::typeof(∂(#107)))(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[19] (::Zygote.var"#219#220"{Tuple{Tuple{Nothing}}, typeof(∂(#107))})(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:203
[20] (::Zygote.var"#1760#back#221"{Zygote.var"#219#220"{Tuple{Tuple{Nothing}}, typeof(∂(#107))}})(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[21] Pullback
@ ./operators.jl:1085 [inlined]
[22] (::typeof(∂(#_#83)))(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[23] (::Zygote.var"#219#220"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#_#83))})(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:203
[24] #1760#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[25] Pullback
@ ./operators.jl:1085 [inlined]
[26] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), Zygote.var"#107#108"{typeof(ρFn_c)}}(Zygote._jvec, Zygote.var"#107#108"{typeof(ρFn_c)}(ρFn_c)))))(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[27] (::Zygote.var"#52#53"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), Zygote.var"#107#108"{typeof(ρFn_c)}}(Zygote._jvec, Zygote.var"#107#108"{typeof(ρFn_c)}(ρFn_c))))})(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
[28] withjacobian(f::Function, args::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:162
[29] jacobian
@ ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:140 [inlined]
[30] hessian_reverse(f::Function, x::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:87
[31] top-level scope
@ REPL[42]:1
[32] top-level scope
@ ~/.julia/packages/CUDA/9T5Sq/src/initialization.jl:66
julia> Zygote.hessian_reverse(ρFn_g, xg)
ERROR: Can't differentiate foreigncall expression
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] Pullback
@ ./iddict.jl:102 [inlined]
[3] (::typeof(∂(get)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[4] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:68 [inlined]
[5] (::typeof(∂(accum_global)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[6] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:79 [inlined]
[7] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[9] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[10] gradtuple1
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:22 [inlined]
[11] #1640#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[12] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41 [inlined]
[14] (::typeof(∂(λ)))(Δ::Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[15] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:76 [inlined]
[16] (::typeof(∂(gradient)))(Δ::Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[17] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:87 [inlined]
[18] (::typeof(∂(#107)))(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[19] (::Zygote.var"#219#220"{Tuple{Tuple{Nothing}}, typeof(∂(#107))})(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:203
[20] (::Zygote.var"#1760#back#221"{Zygote.var"#219#220"{Tuple{Tuple{Nothing}}, typeof(∂(#107))}})(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[21] Pullback
@ ./operators.jl:1085 [inlined]
[22] (::typeof(∂(#_#83)))(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[23] (::Zygote.var"#219#220"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#_#83))})(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:203
[24] #1760#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[25] Pullback
@ ./operators.jl:1085 [inlined]
[26] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), Zygote.var"#107#108"{typeof(ρFn_g)}}(Zygote._jvec, Zygote.var"#107#108"{typeof(ρFn_g)}(ρFn_g)))))(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[27] (::Zygote.var"#52#53"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), Zygote.var"#107#108"{typeof(ρFn_g)}}(Zygote._jvec, Zygote.var"#107#108"{typeof(ρFn_g)}(ρFn_g))))})(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
[28] withjacobian(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:162
[29] jacobian
@ ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:140 [inlined]
[30] hessian_reverse(f::Function, x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:87
[31] top-level scope
@ REPL[44]:1
[32] top-level scope
@ ~/.julia/packages/CUDA/9T5Sq/src/initialization.jl:66
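Since the first derivative is reported to work on the GPU, one possible stopgap (a different technique from the nested AD discussed above, and only an approximation) is to take central finite differences of the gradient. The sketch below is a rough illustration under that assumption. `fd_hessian` is a hypothetical helper, not part of Zygote or Flux, and it is exercised here only on the CPU with a hand-written gradient so that it runs without any packages.

```julia
# Hypothetical helper (not a Zygote/Flux API): approximate the Hessian at `x`
# by central differences of a user-supplied gradient function `grad`.
# Assumes `x` has a floating-point element type.
function fd_hessian(grad, x::AbstractVector; h = 1e-3)
    n = length(x)
    T = eltype(x)
    step = convert(T, h)
    cols = map(1:n) do i
        # One-hot direction built on the host, then copied into an array
        # of the same type as `x`, so no scalar indexing of `x` is needed.
        e = similar(x)
        copyto!(e, T[j == i ? 1 : 0 for j in 1:n])
        # Column i of the Hessian: derivative of the gradient along e_i.
        Array(grad(x .+ step .* e) .- grad(x .- step .* e)) ./ (2 * step)
    end
    H = reduce(hcat, cols)
    return (H .+ H') ./ 2   # symmetrise away finite-difference noise
end

# CPU check using the closed-form gradient of sum(tanh.(x)).
grad_tanh(x) = 1 .- tanh.(x) .^ 2
H_fd = fd_hessian(grad_tanh, [1, 2, 3.4])
```

With the MWE above, the idea would be to pass `y -> Zygote.gradient(ρFn_g, y)[1]` as `grad` and `xg` as `x`. I have not run that on a GPU, and in Float32 the step `h` probably needs to be noticeably larger than in Float64 for the differences not to be dominated by rounding, so treat the result as a rough estimate rather than a replacement for a working `hessian`.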
Hello,
I need to compute the Hessian of a neural network's output with respect to its inputs. I am able to do it on the CPU without any issues, but I'd like to do it on the GPU.
Neither Zygote.hessian nor Zygote.hessian_reverse works for me. Could you please help sort this out? Attaching a simple example:
The error for Zygote.hessian is:
and for Zygote.hessian_reverse is:
Please let me know if I'm doing something wrong, or if there's an alternative solution to obtaining the Hessian on the GPU.