Skip to content

The right way to implement rrule(broadcasted, f, args...) #531

Closed
@dfdx

Description

@dfdx

In some other discussion I proposed a generic implementation of rrule for broadcasted, a slightly modified version of which looks like this (using rrule instead of rrule_via_ad for simplicity here):

# unzip taken from Zygote:
# https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185
# Callable singleton that returns the `i`-th element of its argument.
# The index is carried in the type parameter, so an instance can be interpolated
# into the expression built by `_unzip` below as a plain value.
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]

# Generate the expression
#     (map(StaticGetter{1}(), tuples), ..., map(StaticGetter{N}(), tuples))
# i.e. one `map` per position of the inner tuples, with `N` taken from `Val{N}`.
@generated function _unzip(tuples, ::Val{N}) where {N}
    return Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i in 1:N)...)
end

# Turn a collection of equally long tuples into a tuple of collections:
#     unzip([(1, 'a'), (2, 'b')]) == ([1, 2], ['a', 'b'])
# The number of outputs is read from the first element, so `tuples` must be
# non-empty (`first` throws on an empty collection).
function unzip(tuples)
    N = length(first(tuples))
    return _unzip(tuples, Val(N))
end


# Generic reverse rule for `Broadcast.broadcasted(f, args...)`.
#
# Forward pass: apply `rrule(f, ...)` element-wise, then split the resulting
# `(y, pullback)` pairs into the primal outputs `ys` and the per-element
# pullbacks `pbs`.
#
# Returns `(ys, pullback)`, where `pullback(Δ)` returns
# `(NoTangent(), ∂f, ∂args...)`: one entry for `broadcasted` itself, followed
# by whatever each element pullback returns, regrouped per position.
#
# NOTE(review): `Broadcast.broadcasted` itself produces a lazy object, whereas
#   `ys` here is already a materialized collection. Tooling that compares
#   against the true primal output may disagree with this rule. To confirm.
# NOTE(review): there is no reduction over broadcast-expanded dimensions, so
#   this presumably is only right when all `args` have the same shape. TODO confirm.
function rrule(::typeof(Broadcast.broadcasted), f::F, args...) where F
    # Plain `rrule` is used on purpose: `rrule_via_ad` needs a `RuleConfig` as its
    # first argument, and none is available in this simplified version.
    ys, pbs = unzip(rrule.(f, args...))
    function pullback(Δ)
        # Feed every element pullback its own cotangent, then regroup the
        # per-element result tuples into one collection per position.
        partials = unzip(map((pb, δ) -> pb(δ), pbs, Δ))
        # A position whose entries are all `NoTangent()` collapses to a single
        # `NoTangent()` instead of a collection of them.
        collapsed = map(partials) do dx
            all(d -> d isa NoTangent, dx) ? NoTangent() : dx
        end
        return NoTangent(), collapsed...
    end
    return ys, pullback
end

Empirically, I can see that it works correctly at least in simple cases, e.g.:

f = sin
xs = rand(2)

# reference result: one pullback per element, each seeded with 1.0
pbs = [rrule(f, x)[2] for x in xs]
dxs = [pb(1.0)[2] for pb in pbs]

# same computation through the rrule for broadcasted
_, bcast_pb = rrule(Broadcast.broadcasted, f, xs)
dxs_bcast = bcast_pb(ones(2))[end]

# both routes must give the same derivative for every element
@assert all(dxs .== dxs_bcast)

But when I run test_rrule(Broadcast.broadcasted, f, xs; check_inferred=false) I get a strange error:

test_rrule: broadcasted on typeof(sin),Vector{Float64}: Test Failed at /home/az/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:24
  Expression: isapprox(actual, expected; kwargs...)
  Problem:  Vector{ChainRulesCore.AbstractTangent}[1]
   Evaluated: isapprox(0.0, -8.379999999999999; rtol = 1.0e-9, atol = 1.0e-9)
Stacktrace:
 [1] test_approx(actual::Union{Number, AbstractArray{var"#s79", N} where {var"#s79"<:Number, N}}, expected::Union{Number, AbstractArray{var"#s87", N} where {var"#s87"<:Number, N}}, msg::Any; kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:24
 [2] test_approx(::ChainRulesCore.AbstractZero, x::Any, msg::Any; kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:33
 [3] test_approx(actual::AbstractArray, expected::AbstractArray, msg::Any; kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:80
 [4] macro expansion
   @ ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:238 [inlined]
 [5] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [6] test_rrule(::ChainRulesCore.RuleConfig, ::Any, ::Any, ::Vararg{Any, N} where N; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:194
test_rrule: broadcasted on typeof(sin),Vector{Float64}: Error During Test at /home/az/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:191
  Got exception outside of a @test
  AssertionError: T <: NamedTuple
  Stacktrace:
    [1] test_approx(actual::ChainRulesCore.Tangent{Tuple{Vector{Float64}}, Tuple{Vector{Float64}}}, expected::Any, msg::Any; kwargs::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:112
    [2] test_approx(actual::AbstractArray, expected::AbstractArray, msg::Any; kwargs::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:80
    [3] macro expansion
      @ ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:238 [inlined]
    [4] macro expansion
      @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
    [5] test_rrule(::ChainRulesCore.RuleConfig, ::Any, ::Any, ::Vararg{Any, N} where N; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:194
    [6] test_rrule(::Any, ::Vararg{Any, N} where N; kwargs::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:168
    [7] top-level scope
      @ REPL[20]:1
    [8] eval
      @ ./boot.jl:360 [inlined]
    [9] eval
      @ ./Base.jl:39 [inlined]
   [10] repleval(m::Module, code::Expr, #unused#::String)
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.4.0/scripts/packages/VSCodeServer/src/repl.jl:157
   [11] (::VSCodeServer.var"#69#71"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.4.0/scripts/packages/VSCodeServer/src/repl.jl:123
   [12] with_logstate(f::Function, logstate::Any)
      @ Base.CoreLogging ./logging.jl:491
   [13] with_logger
      @ ./logging.jl:603 [inlined]
   [14] (::VSCodeServer.var"#68#70"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.4.0/scripts/packages/VSCodeServer/src/repl.jl:124
   [15] #invokelatest#2
      @ ./essentials.jl:708 [inlined]
   [16] invokelatest(::Any)
      @ Base ./essentials.jl:706
   [17] macro expansion
      @ ~/.vscode/extensions/julialang.language-julia-1.4.0/scripts/packages/VSCodeServer/src/eval.jl:34 [inlined]
   [18] (::VSCodeServer.var"#53#54")()
      @ VSCodeServer ./task.jl:411

Note that there are two distinct error messages:

Evaluated: isapprox(0.0, -8.379999999999999; rtol = 1.0e-9, atol = 1.0e-9)

and

AssertionError: T <: NamedTuple

Can you see a mistake in this implementation or is it just too complicated for test_rrule() to verify?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions