Cannot generate frule seed via one(x) #618

Open
mcabbott opened this issue Apr 10, 2023 · 1 comment · Fixed by JuliaDiff/ChainRules.jl#710
Labels: design (Requires some design before changes are made), enhancement (New feature or request)

Comments

@mcabbott
Member

Calling frule_via_ad(cfg, (NoTangent(), one(x)), f, x) to compute the derivative works when x is a number, but not in general. So this path:

https://github.com/JuliaDiff/ChainRules.jl/blob/5855c10bdbe691fc07822752f5b5865b9cea44d3/src/rulesets/Base/broadcast.jl#L104-L110

fails if broadcasting over an array of Ref, or almost any struct:

julia> using ChainRules, ChainRulesTestUtils

julia> CFG = ChainRulesTestUtils.TestConfig();

julia> ChainRules.split_bc_forwards(CFG, only, [Ref(1.0), Ref(2.0)])
ERROR: MethodError: no method matching one(::Base.RefValue{Float64})
Stacktrace:
  [1] (::ChainRules.var"#1732#1734"{typeof(frule_via_ad), ChainRulesTestUtils.TestConfig, typeof(only)})(a::Base.RefValue{Float64})
    @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:109
...
 [10] StructArray
    @ ~/.julia/packages/StructArrays/dNQpc/src/structarray.jl:254 [inlined]
 [11] unzip_broadcast(f::ChainRules.var"#1732#1734"{typeof(frule_via_ad), ChainRulesTestUtils.TestConfig, typeof(only)}, args::Vector{Base.RefValue{Float64}})
    @ ChainRules ~/.julia/dev/ChainRules/src/unzipped.jl:40
 [12] split_bc_inner
    @ ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:108 [inlined]
 [13] split_bc_forwards(cfg::ChainRulesTestUtils.TestConfig, f::typeof(only), arg::Vector{Base.RefValue{Float64}})

julia> struct A594 x::Float64 end;

julia> ChainRules.split_bc_forwards(CFG, x -> x.x, A594.(1:3))
ERROR: MethodError: no method matching one(::A594)
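Why the one(x) seed works for numbers: for a scalar input, pushing forward a unit tangent returns exactly f'(x). The sketch below illustrates this with a hypothetical minimal Dual type standing in for frule_via_ad (it is not how ChainRules is implemented), and then shows that the same seed cannot even be constructed for a Ref, matching the error above.

```julia
# Hypothetical minimal forward-mode "dual number", for illustration only;
# ChainRules itself works through frule / frule_via_ad instead.
struct Dual{T<:Real}
    val::T   # primal value
    der::T   # tangent (directional derivative)
end

# Pushforward rules for two primitives.
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)

# For a real x, seeding the tangent with one(x) makes the output tangent f'(x).
derivative(f, x::Real) = f(Dual(x, one(x))).der

derivative(x -> x * x, 3.0)   # returns 6.0
derivative(sin, 0.0)          # returns 1.0

# The same seed cannot be built for a struct: Base defines no one(::RefValue).
try
    one(Ref(1.0))
catch err
    println(err isa MethodError)   # prints true
end
```

A struct has no single "unit" value, which is why the one(x) approach has nothing to call once x stops being a number.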

The tangent-type documentation (https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/tangents.html#Operations-on-a-tangent-type) doesn't list any requirement to define one (or perhaps oneunit), but perhaps there ought to be such a thing?

@oxinabox
Member

I started to work on code that can define a basis seed here:
https://github.com/JuliaComputing/Humpty.jl/blob/main/src/basis.jl
Maybe it is something we should think about.
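To make the "basis seed" idea concrete, here is a minimal sketch under stated assumptions. It does not reproduce Humpty.jl's basis.jl; the helper name basis_seeds and the struct Point2 are hypothetical. Instead of a single one(x), it builds one seed per real-valued field, each a NamedTuple with that field set to one and the others set to zero. In ChainRulesCore such a NamedTuple could be wrapped as Tangent{T}(; seed...) before being passed to frule_via_ad; that wrapping is left out so the sketch runs with Base Julia alone.

```julia
# Hypothetical helper (not part of ChainRulesCore, ChainRules, or Humpty.jl):
# one "basis" seed per real-valued field of x, stored as a NamedTuple with
# that field set to one and the other real fields set to zero.
function basis_seeds(x::T) where {T}
    names = fieldnames(T)
    seeds = NamedTuple[]
    for (i, name) in enumerate(names)
        getfield(x, name) isa Real || continue   # this sketch only seeds real fields
        vals = ntuple(length(names)) do j
            v = getfield(x, names[j])
            if j == i
                one(v)
            elseif v isa Real
                zero(v)
            else
                nothing   # placeholder for non-real fields
            end
        end
        push!(seeds, NamedTuple{names}(vals))
    end
    return seeds
end

struct A594
    x::Float64
end

struct Point2   # hypothetical two-field struct
    x::Float64
    y::Float64
end

basis_seeds(Ref(1.0))          # == [(x = 1.0,)]
basis_seeds(A594(2.0))         # == [(x = 1.0,)]
basis_seeds(Point2(1.0, 2.0))  # == [(x = 1.0, y = 0.0), (x = 0.0, y = 1.0)]
```

For a one-field struct like Ref or A594 this yields a single seed, which is what split_bc_forwards would need in place of one(x). For multi-field structs, pushing each seed forward gives one directional derivative per field, so a rule would have to decide whether it wants all of them or only a particular direction.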

oxinabox added the enhancement and design labels Apr 17, 2023
oxinabox transferred this issue from JuliaDiff/ChainRules.jl Apr 17, 2023
oxinabox reopened this May 1, 2023