Consider this re-implementation of `first`:

using ChainRulesCore
using ChainRulesTestUtils

# Re-implementation of `first` for a 2-tuple, via argument destructuring
_first((x, y)) = x

function ChainRulesCore.rrule(::typeof(_first), xy)
    # The incoming cotangent goes to the first element; the second gets NoTangent()
    _first_pullback(dy) = (NoTangent(), Tangent{typeof(xy)}(dy, NoTangent()))
    return first(xy), _first_pullback
end

Now consider testing it on a tuple whose second element is non-differentiable:
julia> test_rrule(_first, (1.5, "a"))
fd_cotangent = Tangent{Tuple{Float64, String}}(-5.409999999999906, "a")
test_rrule: _first on Float64,String: Error During Test at /home/oxinabox/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/testers.jl:191
Got exception outside of a @test
MethodError: no method matching zero(::String)
Closest candidates are:
zero(::Union{Type{P}, P}) where P<:Dates.Period at /usr/local/src/julia/julia-1.6/usr/share/julia/stdlib/v1.6/Dates/src/periods.jl:53
zero(::SA) where SA<:StaticArrays.StaticArray at /home/oxinabox/.julia/packages/StaticArrays/AHT47/src/linalg.jl:88
zero(::SparseArrays.AbstractSparseArray) at /usr/local/src/julia/julia-1.6/usr/share/julia/stdlib/v1.6/SparseArrays/src/SparseArrays.jl:55
...
Stacktrace:
[1] test_approx(::NoTangent, x::String, msg::String; kwargs::Base.Iterators.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})
@ ChainRulesTestUtils ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/check_result.jl:33
[2] test_approx(actual::Tangent{Tuple{Float64, String}, Tuple{Float64, NoTangent}}, expected::Tangent{Tuple{Float64, String}, Tuple{Float64, String}}, msg::String; kwargs::Base.Iterators.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})
@ ChainRulesTestUtils ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/check_result.jl:92
[3] macro expansion
@ ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/testers.jl:239 [inlined]
[4] macro expansion
@ /usr/local/src/julia/julia-1.6/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
[5] test_rrule(config::ChainRulesTestUtils.ADviaRuleConfig, f::typeof(_first), args::Tuple{Float64, String}; output_tangent::ChainRulesTestUtils.Auto, check_thunked_output_tangent::Bool, fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, rrule_f::Function, check_inferred::Bool, fkwargs::NamedTuple{(), Tuple{}}, rtol::Float64, atol::Float64, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ ChainRulesTestUtils ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/testers.jl:194
[6] test_rrule(config::ChainRulesTestUtils.ADviaRuleConfig, f::Function, args::Tuple{Float64, String})
@ ChainRulesTestUtils ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/testers.jl:186
[7] #test_rrule#47
@ ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/testers.jl:168 [inlined]
[8] test_rrule(::Function, ::Tuple{Float64, String})
@ ChainRulesTestUtils ~/JuliaEnvs/ChainRulesWorld/ChainRulesTestUtils.jl/src/testers.jl:167
[9] top-level scope

This is probably related to `to_vec` preserving non-differentiable elements (which it probably needs to do to make some matrix factorizations work? not sure):
julia> x, from_vec = ChainRulesTestUtils.to_vec((1.5, "a"));
julia> from_vec(x)
(1.5, "a")

and to us still comparing against that preserved value even when the matching field of the tangent is `NoTangent()`:
julia> ChainRulesTestUtils.auto_primal_and_tangent((1.5, "a")).tangent
Tangent{Tuple{Float64, String}}(8.91, NoTangent())
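One possible direction, sketched under assumptions: when a field of the rule's tangent is `NoTangent()`, skip that field instead of comparing it against whatever `to_vec` preserved there (here the `"a"`). The helpers `field_matches` and `tangents_match` are hypothetical and are not ChainRulesTestUtils' `test_approx`; the tolerances are arbitrary, and `fd_like` is a hand-built value shaped like the `fd_cotangent` printed above, not real finite-difference output. Only ChainRulesCore is needed.

```julia
using ChainRulesCore

# Same rule as in the report.
_first((x, y)) = x
function ChainRulesCore.rrule(::typeof(_first), xy)
    _first_pullback(dy) = (NoTangent(), Tangent{typeof(xy)}(dy, NoTangent()))
    return first(xy), _first_pullback
end

# Hypothetical helper (not part of ChainRulesTestUtils): a NoTangent() in the
# rule's output means "nothing to check", so the preserved primal-like value
# (e.g. a String) is never passed to zero or isapprox.
field_matches(actual::NoTangent, expected) = true
field_matches(actual, expected) = isapprox(actual, expected; rtol=1e-9, atol=1e-9)

# Hypothetical helper: compare two tuple-backed Tangents field by field.
function tangents_match(actual::Tangent, expected::Tangent)
    a = ChainRulesCore.backing(actual)
    e = ChainRulesCore.backing(expected)
    length(a) == length(e) || return false
    return all(field_matches(ai, ei) for (ai, ei) in zip(a, e))
end

# Run the rule by hand with an output cotangent of 1.0.
y, first_pb = rrule(_first, (1.5, "a"))
_, xy_bar = first_pb(1.0)  # Tangent{Tuple{Float64, String}}(1.0, NoTangent())

# Illustrative stand-in shaped like the finite-difference cotangent:
# the String field is carried through unchanged.
fd_like = Tangent{Tuple{Float64, String}}(1.0, "a")

println(tangents_match(xy_bar, fd_like))  # prints "true"
```

The frame `[1] test_approx(::NoTangent, x::String, ...)` in the trace, together with the `MethodError` on `zero(::String)`, suggests the existing method compares against `zero(x)`; relaxing that method to accept any `x` is a guess at where such a change would live, not a confirmed patch.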