Closed
I am not sure if this is in the works already, but it would be great to have support for keyword arguments in `@non_differentiable`.
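As a rough sketch of what keyword support could mean, the toy Julia code below uses hypothetical stand-ins (`kw_rrule` is not part of ChainRulesCore, and `nothing` stands in for "no tangent"). It pairs a generic fallback that returns `nothing` with a non-differentiable rule for `isapprox` that declares `kwargs...` and forwards them. In Julia, a call that passes keywords can only dispatch to methods that accept keywords, so declaring `kwargs...` on the specific rule is what lets `isapprox(x, x; atol=1e-5)` reach it.

```julia
# Hypothetical stand-ins for illustration, NOT the real ChainRulesCore API.

# Generic fallback: "no rule defined", signalled by returning `nothing`.
kw_rrule(f, args...; kwargs...) = nothing

# Non-differentiable rule that accepts and forwards keyword arguments.
# Because this method declares `kwargs...`, keyword calls can dispatch to it.
function kw_rrule(::typeof(isapprox), x, y; kwargs...)
    out = isapprox(x, y; kwargs...)
    # Pullback: no tangent (here `nothing`) for the function and both inputs.
    isapprox_pullback(Δ) = (nothing, nothing, nothing)
    return out, isapprox_pullback
end

x = rand(3)
val, back = kw_rrule(isapprox, x, x; atol=1e-5)  # keyword call finds the rule
pos = kw_rrule(isapprox, x, x)                   # positional call still works
```

The design choice here is simply that the rule's signature accepts arbitrary keywords; whether a macro should generate this automatically is exactly the open question of this issue.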
We encountered a related bug with `isapprox` over in DistributionsAD, due to https://github.com/JuliaDiff/ChainRules.jl/blob/c8679c6652eff4deb7ff075d87d91e876842ae59/src/rulesets/Base/nondiff.jl#L123
MWE below:
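```julia
using Zygote

# Calls `isapprox` with a keyword argument: this one errors
function g(x)
    isapprox(x, x, atol=1e-5)
    sum(x)
end

# Calls `isapprox` with positional arguments only: this one works
function f(x)
    isapprox(x, x)
    sum(x)
end

display(Zygote.gradient(f, rand(3)))
display(Zygote.gradient(g, rand(3)))
```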
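`f` and `g` differ only in the `atol` keyword. A plausible explanation, assuming the linked line declares `isapprox` non-differentiable for positional arguments only and that the generic `rrule` fallback returns `nothing`, is sketched below with hypothetical stand-in names (`pos_rrule` is not a real API): the keyword call skips the positional-only method and lands on the fallback.

```julia
# Hypothetical stand-ins for the rule machinery, for illustration only.

# Generic fallback: accepts any positional and keyword arguments, returns `nothing`.
pos_rrule(f, args...; kwargs...) = nothing

# Positional-only rule, mimicking a non-differentiable declaration without keywords.
pos_rrule(::typeof(isapprox), x, y) = (isapprox(x, y), Δ -> (nothing, nothing, nothing))

x = rand(3)
r_pos = pos_rrule(isapprox, x, x)             # the specific method is used
r_kw  = pos_rrule(isapprox, x, x; atol=1e-5)  # only the keyword-accepting fallback matches
```

So `r_pos` is a `(value, pullback)` tuple while `r_kw` is `nothing`, mirroring why `f` works and `g` does not.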
results in:
```
(3-element Fill{Float64}: entries equal to 1.0,)
ERROR: LoadError: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
iterate(::Base.RegexMatchIterator) at regex.jl:552
iterate(::Base.RegexMatchIterator, ::Any) at regex.jl:552
iterate(::LibGit2.GitBranchIter) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LibGit2/src/reference.jl:343
...
Stacktrace:
[1] indexed_iterate(::Nothing, ::Int64) at ./tuple.jl:84
[2] chain_rrule_kw at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/chainrules.jl:101 [inlined]
[3] macro expansion at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/interface2.jl:0 [inlined]
[4] _pullback(::Zygote.Context, ::Base.var"#isapprox##kw", ::NamedTuple{(:atol,),Tuple{Float64}}, ::typeof(isapprox), ::Array{Float64,1}, ::Array{Float64,1}) at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/interface2.jl:12
[5] g at /home/niklas/.julia/dev/ChainRules/mwe.jl:4 [inlined]
[6] _pullback(::Zygote.Context, ::typeof(g), ::Array{Float64,1}) at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/interface2.jl:0
[7] _pullback(::Function, ::Array{Float64,1}) at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/interface.jl:38
[8] pullback(::Function, ::Array{Float64,1}) at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/interface.jl:44
[9] gradient(::Function, ::Array{Float64,1}) at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/interface.jl:53
[10] top-level scope at /home/niklas/.julia/dev/ChainRules/mwe.jl:14
[11] include(::String) at ./client.jl:457
[12] top-level scope at REPL[1]:1
in expression starting at /home/niklas/.julia/dev/ChainRules/mwe.jl:14
```
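The top frame, `indexed_iterate(::Nothing, ::Int64)` inside `chain_rrule_kw`, is what destructuring `nothing` into a value/pullback pair looks like. Assuming `chain_rrule_kw` destructures the `rrule` result this way (the frame suggests it, but that is not confirmed here), the standalone snippet below reproduces the same `MethodError` without Zygote. The variable names are illustrative only.

```julia
# A caller that expects `(value, pullback)` but receives `nothing`.
# `val, pb = result` lowers to `Base.indexed_iterate(result, ...)`,
# which calls `iterate(nothing)`, and no such method exists.
result = nothing
err = try
    val, pb = result
    nothing
catch e
    e
end
```

Here `err` is a `MethodError` for `iterate`, the same error reported at the top of the stacktrace.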