Description
Calling frule_via_ad(cfg, (NoTangent(), one(x)), f, x) to work out the derivative works for numbers, but not in general. So this path fails when broadcasting over an array of Ref, or over almost any struct:
julia> using ChainRules, ChainRulesTestUtils
julia> CFG = ChainRulesTestUtils.TestConfig();
julia> ChainRules.split_bc_forwards(CFG, only, [Ref(1.0), Ref(2.0)])
ERROR: MethodError: no method matching one(::Base.RefValue{Float64})
Stacktrace:
[1] (::ChainRules.var"#1732#1734"{typeof(frule_via_ad), ChainRulesTestUtils.TestConfig, typeof(only)})(a::Base.RefValue{Float64})
@ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:109
...
[10] StructArray
@ ~/.julia/packages/StructArrays/dNQpc/src/structarray.jl:254 [inlined]
[11] unzip_broadcast(f::ChainRules.var"#1732#1734"{typeof(frule_via_ad), ChainRulesTestUtils.TestConfig, typeof(only)}, args::Vector{Base.RefValue{Float64}})
@ ChainRules ~/.julia/dev/ChainRules/src/unzipped.jl:40
[12] split_bc_inner
@ ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:108 [inlined]
[13] split_bc_forwards(cfg::ChainRulesTestUtils.TestConfig, f::typeof(only), arg::Vector{Base.RefValue{Float64}})
julia> struct A594 x::Float64 end;
julia> ChainRules.split_bc_forwards(CFG, x -> x.x, A594.(1:3))
ERROR: MethodError: no method matching one(::A594)
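To make the distinction concrete, here is a minimal, self-contained sketch (no ChainRules needed) that checks which of the element types from the examples above actually have a one method. The helper seed_is_valid is hypothetical and not part of ChainRules. It assumes the one(x) seed is only meaningful for real scalars: for them a single seed gives the full derivative, whereas a struct has no single "unit" direction.

```julia
# Hypothetical helper, not part of ChainRules: is the `one(x)` seed used by the
# broadcast path a sensible forward-mode tangent for `x`?
# Assumption: only real scalars qualify. For them the single seed `one(x)` gives
# the full derivative f'(x). Complex numbers would need a second seed (`im`),
# and structs or Refs have no single unit direction at all.
seed_is_valid(x) = x isa Real

# Same struct as in the example above, written over several lines.
struct A594
    x::Float64
end

for x in (1.0, 2f0, Ref(1.0), A594(1.0))
    # `applicable(one, x)` is true only if some method of `one` accepts `x`.
    println(rpad(string(typeof(x)), 24), " real scalar: ", seed_is_valid(x),
            ", one defined: ", applicable(one, x))
end
```

A guard of this kind is one possible way to keep the one(x) shortcut for numbers while routing other element types elsewhere. It is only an illustration, not how ChainRules resolves the issue.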
I don't see a requirement to define one (or perhaps oneunit) among the operations on a tangent type listed here: https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/tangents.html#Operations-on-a-tangent-type, but perhaps there ought to be such a thing?