
Ability to identify rules that always return AbstractZero #248

Open
oxinabox opened this issue Nov 13, 2020 · 2 comments
Labels
design (Requires some design before changes are made), rule definition helper (relating to helpers for declaring rules)

Comments

@oxinabox
Member

This is related to @non_differentiable.
I think that for operator-overloading-based AD,
if a rule's propagator is always going to return an AbstractZero, the correct thing to do is quite different:
one wants to accept the overloaded type, but return a non-overloaded type.

@willtebbutt
Member

I think this sounds reasonable, but I'm having trouble saying for sure without a concrete example. Any chance that you could concoct one?

@oxinabox added the design and rule definition helper labels on Dec 9, 2020
@oxinabox
Member Author

oxinabox commented Dec 9, 2020

Consider `size`.

The code that Nabla would currently generate from our
@non_differentiable size(::AbstractArray)
is:

function Base.size(x1::Node{<:AbstractArray{<:Any,N}}; kwargs...) where N
    (primal_val, pullback) = rrule(size, unbox(x1); kwargs...)
    tp = tape(x1)  # don't shadow the `tape` function with a local variable
    branch = Branch(primal_val, size, (x1,), kwargs.data, tp, length(tp) + 1, pullback)
    push!(tp, branch)
    return branch  # type is <:Node{NTuple{N, Int}}
end
@inline function preprocess(
    ::typeof(size), y::Branch, ȳ, x1::Union{Any, Node{<:Any}}
)
    return pullback(ȳ)  # this will actually just return `NO_FIELDS, DoesNotExist()`
end
@inline function (
    ::typeof(size), ::Type{Arg{N}}, p, ::Any, ::Any, x1::Union{Any, Node{<:Any}};
    kwargs...
) where N
    return p[N + 1]  # p[1] is the cotangent for `size` itself (dself); skip it, as we don't support functors
end

But what we really want to do is:

function Base.size(x1::Node{<:AbstractArray{<:Any,N}}; kwargs...) where N
    return size(unbox(x1))  # type is NTuple{N, Int}
end
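To make the contrast concrete, here is a minimal, self-contained sketch using hypothetical toy types (`ToyNode` and a plain `Vector{Any}` tape stand in for Nabla's actual `Node`, `Branch`, and tape types, which are not used here). The first variant mirrors the currently generated code by recording a tape entry and returning a node; the second just runs the primal on the unwrapped value.

```julia
# Hypothetical toy node type: a primal value plus a shared tape.
# NOT Nabla's Node/Branch; just enough structure to show the difference.
struct ToyNode{T}
    value::T
    tape::Vector{Any}
end

# Mirrors the currently generated code: record an entry on the tape and
# wrap the result, even though its pullback only yields zero-like cotangents.
function size_recorded(x::ToyNode{<:AbstractArray})
    y = size(x.value)
    push!(x.tape, (size, y))   # stand-in for pushing a Branch
    return ToyNode(y, x.tape)  # result is still a (tracked) node
end

# The desired behaviour: accept the node, return the plain primal result.
size_primal(x::ToyNode{<:AbstractArray}) = size(x.value)
```

Both give back the shape `(3, 4)` for a `3×4` input, but only the first pays for a tape entry and forces downstream code to keep handling a tracked value.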

Possibly we want some API like:

cotangent_types(sig_type_tuple)

which defaults to returning Tuple{Any, Any, ...} (or even something smarter?),
but which @non_differentiable overloads to return Tuple{Zero, DoesNotExist, DoesNotExist},
so that when generating rules we can decide to just run the primal.
On the other hand, we could maybe pull this information out of type inference, and use some Tricks.jl trick to do that without it being super expensive; I don't know.
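A minimal sketch of the `cotangent_types` idea, under several assumptions: `AbstractZero`, `Zero`, and `DoesNotExist` are local stand-ins that only mirror the ChainRulesCore type names (the real package is not loaded), the hand-written `size` method plays the role of what a `@non_differentiable` declaration would generate, and `always_zero` and `overload_strategy` are hypothetical helpers. The check relies on Julia's covariant tuple subtyping: a cotangent tuple type qualifies when it is a subtype of `Tuple{Vararg{AbstractZero}}`.

```julia
# Local stand-ins for ChainRulesCore's zero types (hypothetical; not the package).
abstract type AbstractZero end
struct Zero <: AbstractZero end
struct DoesNotExist <: AbstractZero end

# Proposed API: map a signature tuple type (function type first, then the
# argument types) to the Tuple type of cotangents its pullback can return.
# Default: nothing is known, i.e. Tuple{Any, Any, ...}.
cotangent_types(::Type{<:Tuple}) = Tuple{Vararg{Any}}

# What a `@non_differentiable size(::AbstractArray)` declaration might add.
# The tuple layout is copied from the proposal above.
cotangent_types(::Type{<:Tuple{typeof(size),AbstractArray}}) =
    Tuple{Zero,DoesNotExist,DoesNotExist}

# Hypothetical helper: true when every possible cotangent is an AbstractZero.
# Tuple{Vararg{Any}} is not a subtype, so the default safely returns false.
always_zero(sig::Type{<:Tuple}) = cotangent_types(sig) <: Tuple{Vararg{AbstractZero}}

# Hypothetical rule-generation decision for an operator-overloading AD.
overload_strategy(sig::Type{<:Tuple}) = always_zero(sig) ? :run_primal : :record_on_tape
```

Because tuple types are covariant, the `size` overload also matches concrete signatures such as `Tuple{typeof(size), Matrix{Float64}}`, so a rule generator would only need to query the declared signature once.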


2 participants