
Ignoring non-differentiable elements is not applied to subfields #194

@oxinabox


Consider this re-implementation of `first`:

```julia
using ChainRulesCore
using ChainRulesTestUtils

_first((x,y)) = x
function ChainRulesCore.rrule(::typeof(_first), xy)
    _first_pullback(dy) = (NoTangent(), Tangent{typeof(xy)}(dy, NoTangent()))
    return first(xy), _first_pullback
end
```

Now consider testing it on a tuple whose second element is non-differentiable:

```julia
julia> test_rrule(_first, (1.5, "a"))
fd_cotangent = Tangent{Tuple{Float64, String}}(-5.409999999999906, "a")
test_rrule: _first on Float64,String: Error During Test at /home/oxinabox/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/testers.jl:191
  Got exception outside of a @test
  MethodError: no method matching zero(::String)
  Closest candidates are:
    zero(::Union{Type{P}, P}) where P<:Dates.Period at /usr/local/src/julia/julia-1.6/usr/share/julia/stdlib/v1.6/Dates/src/periods.jl:53
    zero(::SA) where SA<:StaticArrays.StaticArray at /home/oxinabox/.julia/packages/StaticArrays/AHT47/src/linalg.jl:88
    zero(::SparseArrays.AbstractSparseArray) at /usr/local/src/julia/julia-1.6/usr/share/julia/stdlib/v1.6/SparseArrays/src/SparseArrays.jl:55
    ...
  Stacktrace:
    [1] test_approx(::NoTangent, x::String, msg::String; kwargs::Base.Iterators.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})
      @ ChainRulesTestUtils ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/check_result.jl:33
    [2] test_approx(actual::Tangent{Tuple{Float64, String}, Tuple{Float64, NoTangent}}, expected::Tangent{Tuple{Float64, String}, Tuple{Float64, String}}, msg::String; kwargs::Base.Iterators.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})
      @ ChainRulesTestUtils ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/check_result.jl:92
    [3] macro expansion
      @ ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/testers.jl:239 [inlined]
    [4] macro expansion
      @ /usr/local/src/julia/julia-1.6/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
    [5] test_rrule(config::ChainRulesTestUtils.ADviaRuleConfig, f::typeof(_first), args::Tuple{Float64, String}; output_tangent::ChainRulesTestUtils.Auto, check_thunked_output_tangent::Bool, fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, rrule_f::Function, check_inferred::Bool, fkwargs::NamedTuple{(), Tuple{}}, rtol::Float64, atol::Float64, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
      @ ChainRulesTestUtils ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/testers.jl:194
    [6] test_rrule(config::ChainRulesTestUtils.ADviaRuleConfig, f::Function, args::Tuple{Float64, String})
      @ ChainRulesTestUtils ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/testers.jl:186
    [7] #test_rrule#47
      @ ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/testers.jl:168 [inlined]
    [8] test_rrule(::Function, ::Tuple{Float64, String})
      @ ChainRulesTestUtils ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/testers.jl:167
    [9] top-level scope
```

This is probably related to `to_vec` preserving non-differentiable elements (which it may need to do so that some matrix factorizations work; I'm not sure):

```julia
julia> x, from_vec = ChainRulesTestUtils.to_vec((1.5, "a"));

julia> from_vec(x)
(1.5, "a")
```

and to us still comparing against that element even when the matching field of the tangent is `NoTangent()`:

```julia
julia> ChainRulesTestUtils.auto_primal_and_tangent((1.5, "a")).tangent
Tangent{Tuple{Float64, String}}(8.91, NoTangent())
```
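One way the comparison could avoid the error is to skip every field the rule reported as `NoTangent()` before comparing, and to recurse into nested `Tangent`s so the skip also applies to subfields. The sketch below is only an illustration under stated assumptions: `fields_approx` is a hypothetical helper (not part of ChainRulesTestUtils and not its actual fix), it assumes both tangents store the same fields in the same order, and it uses the unexported `ChainRulesCore.backing` accessor.

```julia
using ChainRulesCore

# Hypothetical helper, NOT part of ChainRulesTestUtils: compare the tangent an
# rrule returned (`actual`) with a finite-difference tangent (`expected`) field
# by field, skipping every field the rule marked as NoTangent().
# Assumes both tangents store the same fields in the same order.
function fields_approx(actual::Tangent, expected::Tangent; kwargs...)
    for (a, e) in zip(ChainRulesCore.backing(actual), ChainRulesCore.backing(expected))
        # Non-differentiable field: the FD side still holds the primal value
        # (e.g. "a"), so skip it rather than comparing it with NoTangent().
        a isa NoTangent && continue
        ok = if a isa Tangent && e isa Tangent
            fields_approx(a, e; kwargs...)  # recurse so the skip also applies to subfields
        else
            isapprox(a, e; kwargs...)
        end
        ok || return false
    end
    return true
end

actual   = Tangent{Tuple{Float64, String}}(1.0, NoTangent())
expected = Tangent{Tuple{Float64, String}}(1.0 + 1e-10, "a")
fields_approx(actual, expected; rtol=1e-8)  # true
```

Skipping based on the `actual` side means the rule author's `NoTangent()` decides which fields are ignored, rather than the test trying to guess differentiability from the primal's type.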
