
@non_differentiable should use identical pullbacks when possible #678

@nsajko


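Ideally, the pullbacks returned by the `rrule` methods that `@non_differentiable` generates would be identical (`===`) for the same type signature. Currently each generated rule returns its own pullback with its own type, even when the pullback bodies are the same. Sharing one pullback would presumably reduce compilation latency and help type stability in user code.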
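A minimal sketch of what "identical pullbacks" could mean, assuming a single fieldless (singleton) callable type per arity is acceptable. `NoTangentStub`, `SharedPullback` and `toy_rrule` are hypothetical names used so the example runs without ChainRulesCore; they are not the package's API.

```julia
# Hypothetical stand-in for ChainRulesCore.NoTangent, so the sketch runs on its own
struct NoTangentStub end

# One singleton pullback type per arity N (one tangent for the function, one per argument).
# The type has no fields, so every SharedPullback{N}() is === to every other one.
struct SharedPullback{N} end
(::SharedPullback{N})(_) where {N} = ntuple(_ -> NoTangentStub(), Val(N))

f(x) = rand() * x * 0.1
g(x) = rand() * x * 0.2

# Hand-written rules standing in for what @non_differentiable could generate
toy_rrule(::typeof(f), x) = (f(x), SharedPullback{2}())
toy_rrule(::typeof(g), x) = (g(x), SharedPullback{2}())
```

With this design, every non-differentiable function of the same arity reuses one already-compiled pullback method instead of introducing a new closure type.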

Test:

```julia
f(x) = rand()*x*0.1
g(x) = rand()*x*0.2
using ChainRulesCore
@non_differentiable f(::Any)
@non_differentiable g(::Any)
using Test
@test last(rrule(f, 0.3)) === last(rrule(g, 0.4))
```

Failure:

```
julia> @test last(rrule(f, 0.3)) === last(rrule(g, 0.4))
Test Failed at REPL[7]:1
  Expression: last(rrule(f, 0.3)) === last(rrule(g, 0.4))
   Evaluated: var"#f_pullback#2"() === var"#g_pullback#4"()

ERROR: There was an error during testing
```
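The names `#f_pullback#2` and `#g_pullback#4` suggest that each expansion defines its own local pullback function. In Julia every local function gets its own closure type, so two rules with textually identical pullbacks still return values of different types. A package-free illustration, where `rule_f` and `rule_g` are hypothetical stand-ins for generated rules:

```julia
# Each rule defines its own local pullback; Julia gives each local function its own type.
function rule_f(x)
    f_pullback(_) = (:no_tangent, :no_tangent)  # identical body...
    return (x, f_pullback)
end

function rule_g(x)
    g_pullback(_) = (:no_tangent, :no_tangent)  # ...but a distinct closure type
    return (x, g_pullback)
end

pf = last(rule_f(0.3))
pg = last(rule_g(0.4))
println(typeof(pf) == typeof(pg))  # false
println(pf === last(rule_f(0.5)))  # true: same closure type, nothing captured
```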

I'm not good with macros, so I probably won't be tackling this myself.
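For whoever picks this up, here is a toy macro sketch showing the shape of the change, assuming the generated rule can simply return a shared singleton pullback. `@toy_non_differentiable`, `toy_rrule`, `SharedPullback` and `NoTangentStub` are hypothetical; the real `@non_differentiable` handles far more signature forms (varargs, keyword arguments, `where` clauses) than this does.

```julia
# Hypothetical stand-ins, not ChainRulesCore's API
struct NoTangentStub end
struct SharedPullback{N} end
(::SharedPullback{N})(_) where {N} = ntuple(_ -> NoTangentStub(), Val(N))

function toy_rrule end

# Toy version of @non_differentiable: every argument must be written as `::T` or `x::T`
macro toy_non_differentiable(sig)
    fname = sig.args[1]
    argtypes = [a.args[end] for a in sig.args[2:end]]      # the T in `::T` / `x::T`
    argnames = [gensym(:x) for _ in argtypes]
    typed_args = [:($n::$T) for (n, T) in zip(argnames, argtypes)]
    N = length(argtypes) + 1  # one tangent for the function itself, one per argument
    # Escaped so the rule is defined in the caller's module
    return esc(quote
        function toy_rrule(::typeof($fname), $(typed_args...))
            # Same pullback instance for every rule of this arity
            return ($fname($(argnames...)), SharedPullback{$N}())
        end
    end)
end

f(x) = rand() * x * 0.1
g(x) = rand() * x * 0.2
@toy_non_differentiable f(::Any)
@toy_non_differentiable g(::Any)
```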
