Keyword arguments for @non_differentiable #216

Closed
FluxML/Zygote.jl
#788
@nmheim

I am not sure whether this is already in the works, but it would be great to have support for keyword arguments in @non_differentiable.
We ran into a related bug with isapprox over in DistributionsAD, caused by the rule at https://github.com/JuliaDiff/ChainRules.jl/blob/c8679c6652eff4deb7ff075d87d91e876842ae59/src/rulesets/Base/nondiff.jl#L123
A minimal working example (MWE) follows. The only difference between f and g is that g passes the atol keyword to isapprox:

using Zygote

function g(x)
    isapprox(x, x, atol=1e-5)
    sum(x)
end

function f(x)
    isapprox(x, x)
    sum(x)
end

display(Zygote.gradient(f, rand(3)))
display(Zygote.gradient(g, rand(3)))

The gradient of f is computed as expected, while the gradient of g throws:

(3-element Fill{Float64}: entries equal to 1.0,)
ERROR: LoadError: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Base.RegexMatchIterator) at regex.jl:552
  iterate(::Base.RegexMatchIterator, ::Any) at regex.jl:552
  iterate(::LibGit2.GitBranchIter) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LibGit2/src/reference.jl:343
  ...
Stacktrace:
 [1] indexed_iterate(::Nothing, ::Int64) at ./tuple.jl:84
 [2] chain_rrule_kw at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/chainrules.jl:101 [inlined]
 [3] macro expansion at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/interface2.jl:0 [inlined]
 [4] _pullback(::Zygote.Context, ::Base.var"#isapprox##kw", ::NamedTuple{(:atol,),Tuple{Float64}}, ::typeof(isapprox), ::Array{Float64,1}, ::Array{Float64,1}) at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/interface2.jl:12
 [5] g at /home/niklas/.julia/dev/ChainRules/mwe.jl:4 [inlined]
 [6] _pullback(::Zygote.Context, ::typeof(g), ::Array{Float64,1}) at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/interface2.jl:0
 [7] _pullback(::Function, ::Array{Float64,1}) at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/interface.jl:38
 [8] pullback(::Function, ::Array{Float64,1}) at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/interface.jl:44
 [9] gradient(::Function, ::Array{Float64,1}) at /home/niklas/.julia/packages/Zygote/Xgcgs/src/compiler/interface.jl:53
 [10] top-level scope at /home/niklas/.julia/dev/ChainRules/mwe.jl:14
 [11] include(::String) at ./client.jl:457
 [12] top-level scope at REPL[1]:1
in expression starting at /home/niklas/.julia/dev/ChainRules/mwe.jl:14
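
For illustration, a rule that tolerates keyword arguments can also be written by hand. The sketch below assumes ChainRulesCore 1.x, where the "no derivative" tangent is spelled `NoTangent()`. It uses a hypothetical helper `close_enough` rather than overloading `isapprox` itself. It is not the fix adopted by ChainRules or Zygote, only a sketch of what keyword support in @non_differentiable would amount to.

```julia
using ChainRulesCore

# Hypothetical helper (not part of any package): wraps isapprox so that the
# example does not add methods for a Base function.
close_enough(x, y; atol=0.0) = isapprox(x, y; atol=atol)

# Hand-written reverse rule that accepts arbitrary keyword arguments and
# declares the function itself and both positional inputs non-differentiable.
function ChainRulesCore.rrule(::typeof(close_enough), x, y; kwargs...)
    # One tangent per input: the function, x, and y.
    close_enough_pullback(_) = (NoTangent(), NoTangent(), NoTangent())
    return close_enough(x, y; kwargs...), close_enough_pullback
end

# Calling the rule directly with a keyword argument returns a value and a
# pullback instead of `nothing`.
result, back = rrule(close_enough, [1.0, 2.0], [1.0, 2.0]; atol=1e-5)
tangents = back(1.0)
```

Because the rule forwards `kwargs...` to the primal call, the keyword form returns a `(value, pullback)` pair, whereas the stack trace above shows `chain_rrule_kw` trying to destructure `nothing`.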
