Closed
Description
In some other discussion I proposed a generic implementation of rrule
for broadcasted
, a slightly modified version of which looks like this (using rrule
instead of rrule_via_ad
for simplicity here):
# unzip taken from Zygote:
# https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
@generated function _unzip(tuples, ::Val{N}) where {N}
Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i ∈ 1:N)...)
end
function unzip(tuples)
N = length(first(tuples))
_unzip(tuples, Val(N))
end
function rrule(::typeof(Broadcast.broadcasted), f::F, args...) where F
ys, pbs = unzip(rrule_via_ad.(f, args...))
function pullback(Δ)
dxs = map((pb, Δ) -> pb(Δ), pbs, Δ) |> unzip
dxs = [all(dx .== NoTangent()) ? NoTangent() : dx for dx in dxs]
return NoTangent(), dxs...
end
return ys, pullback
end
Empirically, I can see that it works correctly at least in simple cases, e.g.:
f = sin
xs = rand(2)
# manually get pullbacks for each element and apply them to seed 1.0
pbs = [rrule(f, x)[2] for x in xs]
dxs = [pbs[1](1.0)[2], pbs[2](1.0)[2]]
# use rrule for broadcasted
_, bcast_pb = rrule(Broadcast.broadcasted, f, xs)
dxs_bcast = bcast_pb(ones(2))[end]
@assert all(dxs .== dxs_bcast)
But when I run test_rrule(Broadcast.broadcasted, f, xs; check_inferred=false)
I get a strange error:
test_rrule: broadcasted on typeof(sin),Vector{Float64}: Test Failed at /home/az/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:24
Expression: isapprox(actual, expected; kwargs...)
Problem: Vector{ChainRulesCore.AbstractTangent}[1]
Evaluated: isapprox(0.0, -8.379999999999999; rtol = 1.0e-9, atol = 1.0e-9)
Stacktrace:
[1] test_approx(actual::Union{Number, AbstractArray{var"#s79", N} where {var"#s79"<:Number, N}}, expected::Union{Number, AbstractArray{var"#s87", N} where {var"#s87"<:Number, N}}, msg::Any; kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:24
[2] test_approx(::ChainRulesCore.AbstractZero, x::Any, msg::Any; kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:33
[3] test_approx(actual::AbstractArray, expected::AbstractArray, msg::Any; kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:80
[4] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:238 [inlined]
[5] macro expansion
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
[6] test_rrule(::ChainRulesCore.RuleConfig, ::Any, ::Any, ::Vararg{Any, N} where N; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:194
test_rrule: broadcasted on typeof(sin),Vector{Float64}: Error During Test at /home/az/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:191
Got exception outside of a @test
AssertionError: T <: NamedTuple
Stacktrace:
[1] test_approx(actual::ChainRulesCore.Tangent{Tuple{Vector{Float64}}, Tuple{Vector{Float64}}}, expected::Any, msg::Any; kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:112
[2] test_approx(actual::AbstractArray, expected::AbstractArray, msg::Any; kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:80
[3] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:238 [inlined]
[4] macro expansion
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
[5] test_rrule(::ChainRulesCore.RuleConfig, ::Any, ::Any, ::Vararg{Any, N} where N; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:194
[6] test_rrule(::Any, ::Vararg{Any, N} where N; kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:168
[7] top-level scope
@ REPL[20]:1
[8] eval
@ ./boot.jl:360 [inlined]
[9] eval
@ ./Base.jl:39 [inlined]
[10] repleval(m::Module, code::Expr, #unused#::String)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.4.0/scripts/packages/VSCodeServer/src/repl.jl:157
[11] (::VSCodeServer.var"#69#71"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.4.0/scripts/packages/VSCodeServer/src/repl.jl:123
[12] with_logstate(f::Function, logstate::Any)
@ Base.CoreLogging ./logging.jl:491
[13] with_logger
@ ./logging.jl:603 [inlined]
[14] (::VSCodeServer.var"#68#70"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.4.0/scripts/packages/VSCodeServer/src/repl.jl:124
[15] #invokelatest#2
@ ./essentials.jl:708 [inlined]
[16] invokelatest(::Any)
@ Base ./essentials.jl:706
[17] macro expansion
@ ~/.vscode/extensions/julialang.language-julia-1.4.0/scripts/packages/VSCodeServer/src/eval.jl:34 [inlined]
[18] (::VSCodeServer.var"#53#54")()
@ VSCodeServer ./task.jl:411
Note 2 error messages:
Evaluated: isapprox(0.0, -8.379999999999999; rtol = 1.0e-9, atol = 1.0e-9)
and
AssertionError: T <: NamedTuple
Can you see a mistake in this implementation or is it just too complicated for test_rrule()
to verify?
Metadata
Metadata
Assignees
Labels
No labels