pushforward_function and pullback_function are confused by tuples vs single input #99

Open
@gdalle

The setup:

julia> import AbstractDifferentiation as AD

julia> using ForwardDiff: ForwardDiff

julia> using Zygote: Zygote

julia> b1 = AD.ZygoteBackend();

julia> b2 = AD.ForwardDiffBackend();

julia> f(x) = x .^ 2;

julia> x = rand(3)
3-element Vector{Float64}:
 0.4953469957333393
 0.16373195021545772
 0.9601871509472656

julia> y = f(x)
3-element Vector{Float64}:
 0.24536864618204485
 0.026808151521357126
 0.921959364844227

julia> dx = rand(size(x)...)
3-element Vector{Float64}:
 0.5968881542176618
 0.05494767011762569
 0.18061398390944328

julia> dy = rand(size(y)...)
3-element Vector{Float64}:
 0.9491707280920829
 0.2878716471988746
 0.15674572721525504

A pushforward with the Zygote backend doesn't accept a single array as input:

julia> pf1 = AD.pushforward_function(b1, f, x);

julia> pf2 = AD.pushforward_function(b2, f, x);

julia> pf1((dx,))  # works
([0.5913335079610738, 0.017993378376308967, 0.3468464532624872],)

julia> pf1(dx)  # fails but shouldn't
ERROR: ArgumentError: Tuple contains 3 elements, must contain exactly 1 element
Stacktrace:
 [1] only(x::Tuple{Float64, Float64, Float64})
   @ Base.Iterators ./iterators.jl:1531
 [2] (::AbstractDifferentiation.var"#14#16"{Vector{Float64}, typeof(f), Tuple{Vector{Float64}}})(::Float64, ::Vararg{Float64})
   @ AbstractDifferentiation ~/Work/GitHub/Julia/AbstractDifferentiation.jl/src/AbstractDifferentiation.jl:172
 [3] (::AbstractDifferentiation.var"#25#27"{AbstractDifferentiation.ReverseRuleConfigBackend{Zygote.ZygoteRuleConfig{Zygote.Context{false}}}, AbstractDifferentiation.var"#14#16"{Vector{Float64}, typeof(f), Tuple{Vector{Float64}}}, Tuple{Float64, Float64, Float64}})(ws::Nothing)
   @ AbstractDifferentiation ~/Work/GitHub/Julia/AbstractDifferentiation.jl/src/AbstractDifferentiation.jl:249
 [4] jacobian(::AbstractDifferentiation.ReverseRuleConfigBackend{Zygote.ZygoteRuleConfig{Zygote.Context{false}}}, ::Function, ::Float64, ::Float64, ::Vararg{Float64})
   @ AbstractDifferentiationChainRulesCoreExt ~/Work/GitHub/Julia/AbstractDifferentiation.jl/src/AbstractDifferentiation.jl:551
 [5] (::AbstractDifferentiation.var"#13#15"{AbstractDifferentiation.ReverseRuleConfigBackend{Zygote.ZygoteRuleConfig{Zygote.Context{false}}}, typeof(f), Tuple{Vector{Float64}}})(ds::Vector{Float64})
   @ AbstractDifferentiation ~/Work/GitHub/Julia/AbstractDifferentiation.jl/src/AbstractDifferentiation.jl:166
 [6] top-level scope
   @ ~/Work/GitHub/Julia/ImplicitDifferentiation.jl/test/playground.jl:54

julia> pf2((dx,))  # works
([0.5913335079610738, 0.017993378376308967, 0.3468464532624872],)

julia> pf2(dx)  # works
([0.5913335079610738, 0.017993378376308967, 0.3468464532624872],)
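
Reading the stack trace for `pf1(dx)`, frames [2] and [4] receive `::Float64` arguments, which suggests the vector is splatted into scalars before `only` is applied. The following is a minimal sketch of that failure mode in plain Julia. `splat_only` is a hypothetical stand-in written for illustration; it is not the package's actual code.

```julia
# Hypothetical stand-in for the suspected pattern: splat the input,
# then require exactly one element.
splat_only(ds) = only(tuple(ds...))

dx = [0.5, 0.1, 0.2]

# A 1-tuple splats to a single element, so `only` returns the vector.
splat_only((dx,))

# A bare vector splats to three scalars, so `only` throws an ArgumentError.
err = try
    splat_only(dx)
    nothing
catch e
    e
end
```

If this reading is right, the bare-array call fails only because of how the input is unpacked, not because of anything specific to `f`.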

A pullback with the ForwardDiff backend doesn't accept a tuple as input:

julia> pb1 = AD.pullback_function(b1, f, x);

julia> pb2 = AD.pullback_function(b2, f, x);

julia> pb1(dy)  # works
([0.9403377371968791, 0.09426757241521588, 0.301010466475946],)

julia> pb1((dy,))  # works
([0.9403377371968791, 0.09426757241521588, 0.301010466475946],)

julia> pb2(dy)  # works
([0.9403377371968791, 0.09426757241521588, 0.301010466475946],)

julia> pb2((dy,))  # fails but shouldn't
ERROR: AssertionError: length(vs) == length(ws)
Stacktrace:
 [1] (::AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)})(xs::Vector{ForwardDiff.Dual{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3}})
   @ AbstractDifferentiation ~/Work/GitHub/Julia/AbstractDifferentiation.jl/src/AbstractDifferentiation.jl:231
 [2] vector_mode_dual_eval!
   @ ~/.julia/packages/ForwardDiff/vXysl/src/apiutils.jl:24 [inlined]
 [3] vector_mode_gradient(f::AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3}}})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:89
 [4] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3}}}, ::Val{true})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:19
 [5] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3}}})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:17
 [6] gradient(ba::AbstractDifferentiation.ForwardDiffBackend{Nothing}, f::Function, x::Vector{Float64})
   @ AbstractDifferentiationForwardDiffExt ~/Work/GitHub/Julia/AbstractDifferentiation.jl/ext/AbstractDifferentiationForwardDiffExt.jl:46
 [7] (::AbstractDifferentiation.var"#87#89"{AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f), Tuple{Vector{Float64}}})(ws::Tuple{Vector{Float64}})
   @ AbstractDifferentiation ~/Work/GitHub/Julia/AbstractDifferentiation.jl/src/AbstractDifferentiation.jl:224
 [8] top-level scope
   @ ~/Work/GitHub/Julia/ImplicitDifferentiation.jl/test/playground.jl:63

Both of these call styles (a bare array and a 1-tuple) are documented in the README, so each backend should accept both.
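
One possible way to make the two call styles equivalent is to normalize the argument before it reaches backend-specific code. The sketch below is an illustration only: `as_tuple`, `normalize_input`, and `strict_pf` are hypothetical names, not part of AbstractDifferentiation, and `strict_pf` is a hand-written pushforward of `f(x) = x .^ 2` that accepts tuples only.

```julia
# Hypothetical helper: leave tuples untouched, wrap anything else in a 1-tuple.
as_tuple(v::Tuple) = v
as_tuple(v) = (v,)

# Hypothetical wrapper: normalize the argument before calling `pf`.
normalize_input(pf) = v -> pf(as_tuple(v))

# Stand-in pushforward for f(x) = x .^ 2 at a fixed x. Like the failing
# cases above, it handles only one of the two input styles (tuples).
x = [1.0, 2.0, 3.0]
strict_pf(ds::Tuple) = (2 .* x .* only(ds),)

pf = normalize_input(strict_pf)
dx = [1.0, 0.0, 1.0]

pf(dx)     # ([2.0, 0.0, 6.0],)
pf((dx,))  # ([2.0, 0.0, 6.0],)
```

With the wrapper in place, both calls take the same code path, so they cannot disagree. Whether the package should normalize at this level or inside each backend is a design question this sketch does not settle.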

Labels: bug (Something isn't working)
