rrule_via_ad // frule_via_ad Calling back into AD from ChainRules #68

Closed
oxinabox opened this issue Nov 21, 2019 · 12 comments · Fixed by #363

Comments

@oxinabox
Member

This was originally discussed in JuliaDiff/ChainRules.jl#12 (comment)
and in a few other places.

Basically, when defining a chainrule (frule or rrule),
it would often be nice to be able to say "Give me the chainrule for some function, and if there is not one predefined, use AD to get it"
as part of your rule definition.
Right now this is not possible, except by hard-coding an AD (e.g. Zygote) in.

Whereas if we had a function that an AD system could overload,
then we could do that.

It would also provide a common API for all ADs that support it.

This would help, e.g., with higher-order functions like map, broadcast, etc.
JuliaDiff/ChainRules.jl#122

There is some fiddliness involved around making sure it can be overloaded multiple times, that the user can choose which AD to use, and that it compiles away, but I think we can sort it all out.
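A minimal Base-only sketch of the "function an AD system could overload" idea, with everything hypothetical: `rrule_via_ad` here is a toy hook named after this issue (not ChainRulesCore's actual signature), and the two "ADs" are stand-ins (central finite differences and a hand-written rule table). Each AD overloads the hook for its own marker type, so several can coexist, the caller chooses one by passing its marker, and the choice is resolved by dispatch.

```julia
# Hypothetical marker types; each "AD" overloads the hook for its own type.
abstract type HypotheticalAD end
struct FiniteDiffAD <: HypotheticalAD end  # stand-in: central finite differences
struct TableAD <: HypotheticalAD end       # stand-in: hand-written rule table

# The hook a rule author calls; returns (y, pullback) with pullback(ȳ) -> (∂f, ∂x).
function rrule_via_ad end

function rrule_via_ad(::FiniteDiffAD, f, x::Real)
    h = 1e-6
    dfdx = (f(x + h) - f(x - h)) / (2h)   # central-difference derivative estimate
    return f(x), ȳ -> (nothing, ȳ * dfdx)
end

rrule_via_ad(::TableAD, ::typeof(sin), x::Real) = (sin(x), ȳ -> (nothing, ȳ * cos(x)))
rrule_via_ad(::TableAD, f, x) = error("TableAD has no rule for $f")

# A rule for a higher-order function, written once against the hook only.
apply_to(f, x) = f(x)
function rrule_apply_to(ad::HypotheticalAD, f, x)
    y, f_pullback = rrule_via_ad(ad, f, x)
    apply_to_pullback(ȳ) = (nothing, f_pullback(ȳ)...)  # (∂apply_to, ∂f, ∂x)
    return y, apply_to_pullback
end
```

Adding another AD means adding methods for a new marker type, so nothing is overwritten and load order does not matter.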


@jessebett reminded me of this today

YingboMa pushed a commit to YingboMa/ChainRulesCore.jl that referenced this issue Dec 21, 2019
@oxinabox oxinabox added this to the v1 milestone Dec 25, 2019
@oxinabox
Member Author

oxinabox commented Jan 4, 2020

A solution that @willtebbutt and I came up with the other night:
make all rules take an extra argument for the configuration,
which would be a struct:

struct Config{F, R}
    ad_frule::F
    ad_rrule::R
end

Here ad_frule and ad_rrule are functions conforming to the frule/rrule API, but that invoke a forward-mode or reverse-mode AD.
Or they can be nothing if either one is not provided.

We might want to specify that they have to be actual functions, so that we can dispatch on them being present,
e.g. for rrule:

function rrule(::typeof(map), config::Config{<:Any, <:Function}, f, xs...)
    y = map(f, xs...)
    function map_pullback(ȳ)
        # one cotangent each for map, f, and every array in xs
        ntuple(length(xs)+2) do full_i
            full_i == 1 && return NO_FIELDS       # ∂map
            full_i == 2 && return DoesNotExist()  # ∂f (not differentiated here)
            i = full_i-2
            @thunk map(ȳ, xs...) do ȳi, xis...
                # call back into the configured reverse-mode AD for f
                _, pullback = config.ad_rrule(f, xis...)
                ∂xis = pullback(ȳi)
                extern(∂xis[i+1])  # +1 to skip ∂self
            end
        end
    end
    return y, map_pullback
end

Since we dispatch on the Config having an ad_rrule function, a Config without one will hit the generic rule-not-found fallback (which returns nothing).
If it does have one, then we can assume the ad_rrule will itself check for actual rrules.

One worry is the case where one is in some non-AD scenario, but where one knows all the things should have rules.
I think for that case the user can set the Config to use the current checked_rrule/checked_frule, which errors if the rule is not found.
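The dispatch behaviour described above, plus the checked variant for non-AD settings, can be sketched in plain Julia. Everything here is hypothetical: `known_rrule` is a toy rule table standing in for real rules, `checked_known_rrule` plays the role of checked_rrule, and `square_of` is an arbitrary example rule that needs to call back into whatever is configured as `ad_rrule`.

```julia
# Hypothetical version of the Config struct above (not ChainRulesCore's API).
struct Config{F, R}
    ad_frule::F
    ad_rrule::R
end

# Toy rule table standing in for "rules that actually exist"; `nothing` = no rule.
known_rrule(::typeof(sin), x::Real) = (sin(x), ȳ -> (nothing, ȳ * cos(x)))
known_rrule(f, x) = nothing

# checked_rrule-style lookup: errors instead of returning `nothing`.
function checked_known_rrule(f, x)
    r = known_rrule(f, x)
    r === nothing && error("no rule found for $f")
    return r
end

# A rule that needs to call back into the configured ad_rrule for f: y = f(x)^2.
square_of(f, x) = f(x)^2
function rrule_square_of(config::Config{<:Any, <:Function}, f, x)
    fx, f_pullback = config.ad_rrule(f, x)
    function square_of_pullback(ȳ)
        _, ∂x = f_pullback(2fx * ȳ)    # chain rule through the outer square
        return (nothing, nothing, ∂x)  # (∂square_of, ∂f, ∂x)
    end
    return fx^2, square_of_pullback
end
# No ad_rrule available (e.g. `nothing`): behave like "rule not found".
rrule_square_of(::Config, f, x) = nothing
```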

@oxinabox
Member Author

oxinabox commented Jan 5, 2020

@jrevels's old idea, which was in the source code for ages and which I have meant to transfer, is:

In some weird ideal sense, the fallback for e.g. frule should actually be "get
the derivative via forward-mode AD". This is necessary to enable mixed-mode
rules, where e.g. frule is used within a rrule definition. For example,
broadcasted functions may not themselves be forward-mode primitives, but are
often forward-mode differentiable.
ChainRulesCore, by design, is decoupled from any specific AD implementation. How,
then, do we know which AD to fall back to when there isn't a primitive defined?
Well, if you're a greedy AD implementation, you can just overload frule and/or
rrule to use your AD directly. However, this won't play nice with other AD
packages doing the same thing, and thus could cause load-order-dependent
problems for downstream users.
It turns out, Cassette solves this problem nicely by allowing AD authors to
overload the fallbacks w.r.t. their own context. Example using ForwardDiff:

using ChainRulesCore, ForwardDiff, Cassette
Cassette.@context MyChainRuleCtx
# ForwardDiff, itself, can call `my_frule` instead of
# `frule` to utilize the ForwardDiff-injected ChainRulesCore
# infrastructure
my_frule(args...) = Cassette.recurse(MyChainRuleCtx(), frule, args...)
function Cassette.overdub(::MyChainRuleCtx, ::typeof(frule), f, x::Number)
    r = frule(f, x)
    if isa(r, Nothing)
        fx, df = (f(x), (_, Δx) -> ForwardDiff.derivative(f, x) * Δx)
    else
        fx, df = r
    end
    return fx, df
end

This could work. It would basically fix any uses of checked_rrule and checked_frule in existing code, by rewriting them to never error, since frule would never return nothing.
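The "fall back to forward-mode AD when there is no rule" behaviour can be sketched without Cassette or ForwardDiff: a wrapper first asks for a hand-written rule and, when it gets `nothing`, pushes a tiny dual number through the function. `ToyDual`, `toy_frule` and `frule_or_ad` are hypothetical names, and the dual type only supports `+`, `*` and `sin` between duals, which is enough for the example.

```julia
# Minimal dual number (value, tangent): a stand-in for a real forward-mode AD.
struct ToyDual{T<:Real}
    val::T
    tan::T
end
Base.:+(a::ToyDual, b::ToyDual) = ToyDual(a.val + b.val, a.tan + b.tan)
Base.:*(a::ToyDual, b::ToyDual) = ToyDual(a.val * b.val, a.tan * b.val + a.val * b.tan)
Base.sin(a::ToyDual) = ToyDual(sin(a.val), cos(a.val) * a.tan)

# Hand-written frules; `nothing` means "no rule", as in the quoted design.
toy_frule(::typeof(sin), x, Δx) = (sin(x), cos(x) * Δx)
toy_frule(f, x, Δx) = nothing

# Use a hand-written rule if one exists; otherwise fall back to forward-mode AD.
function frule_or_ad(f, x::Real, Δx::Real)
    r = toy_frule(f, x, Δx)
    r === nothing || return r
    d = f(ToyDual(promote(float(x), float(Δx))...))
    return d.val, d.tan
end
```

With this shape the wrapper never returns `nothing`, which is the property that would make the checked_* variants unnecessary.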

I know that, for the purposes of allowing 2nd derivatives etc.,
@YingboMa is already overdubbing frule in ForwardDiff2.
That either makes this easier, because it just adds an extra thing to the existing overdub,
or harder, because it ends up adding an extra layer of overdub, and bad things happen to performance when you nest Cassette.

@willtebbutt
Member

willtebbutt commented May 13, 2020

A slight variation on the Config option discussed above.

Firstly, add a couple of extra types:

abstract type AbstractAD{T} end

struct ForwardsModeAD{T} <: AbstractAD{T}
    pushforward::T
end

struct ReverseModeAD{T} <: AbstractAD{T}
    pullback::T
end

These simply wrap an AD's entry point, e.g. Zygote.pullback, to produce a thing that ChainRules knows about. ChainRules will then assume an API for the wrapped function.

Implement a fallback definition of frule and rrule:

frule(::AbstractAD, tangents, values...) = frule(tangents, values...)

rrule(::AbstractAD, values...) = rrule(values...)

This gives rule-authors two options.

  1. implement a new method of frule or rrule that completely ignores any AD information.
  2. implement a new method of frule or rrule that exploits whichever AD is passed in.
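A self-contained mock of this design, showing both options for rule authors: the `sin` rule ignores the AD (option 1) and is reached from an AD-carrying call via the fallback, while the `map` rule uses the wrapped AD to differentiate `f` (option 2). Here `rrule` is a local function rather than ChainRulesCore's, and `toy_pullback` is a hypothetical AD entry point assumed to follow a Zygote.pullback-like convention (the returned closure gives one cotangent per argument, without ∂self).

```julia
# Local mock; not ChainRulesCore.rrule.
abstract type AbstractAD{T} end
struct ReverseModeAD{T} <: AbstractAD{T}
    pullback::T
end

rrule(f, args...) = nothing                        # no rule known
rrule(::AbstractAD, values...) = rrule(values...)  # AD-agnostic fallback

# Option 1: a rule that ignores any AD information.
rrule(::typeof(sin), x::Real) = (sin(x), ȳ -> (nothing, ȳ * cos(x)))

# Option 2: a rule that exploits the AD passed in, to differentiate f inside map.
function rrule(ad::ReverseModeAD, ::typeof(map), f, xs::AbstractVector)
    results = map(x -> ad.pullback(f, x), xs)          # (y, back) per element
    ys = map(first, results)
    function map_pullback(ȳs)
        ∂xs = map((r, ȳ) -> r[2](ȳ)[1], results, ȳs)  # back(ȳ) -> (∂x,)
        return (nothing, nothing, ∂xs)                 # (∂map, ∂f, ∂xs)
    end
    return ys, map_pullback
end

# Hypothetical AD entry point: defers to the local rrules and drops ∂self.
function toy_pullback(f, x)
    r = rrule(f, x)
    r === nothing && error("toy AD cannot differentiate $f")
    y, back = r
    return y, ȳ -> Base.tail(back(ȳ))
end
```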

This gives AD package authors various options:

  • default to always passing in their AD
  • default to not passing in any AD. They'll still get the AD-independent rules, but will have to rely on their AD being sufficiently clever to differentiate through stuff.
  • any mixture of the above, e.g. electing to use different ADs for particular functions, such as the forwards-in-reverse trick for mapping and broadcasting unary functions of a scalar.
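To make the forwards-in-reverse trick concrete: for a unary scalar function, the derivative at each element can be computed during the forward pass with forward-mode differentiation, so the pullback reduces to an elementwise multiply over a plain array. This sketch uses the complex-step method as a stand-in for a forward-mode AD (accurate to rounding for real-analytic functions that accept complex input); `complex_step_derivative` and `map_fwd_in_rev_rrule` are hypothetical names.

```julia
# Derivative of a real-analytic scalar f at x via the complex-step method,
# used here as a stand-in for forward-mode AD.
complex_step_derivative(f, x::Real; h = 1e-20) = imag(f(complex(float(x), h))) / h

# Reverse rule for map(f, xs) that computes all per-element derivatives
# in the forward pass and keeps them as a plain array (no closures per element).
function map_fwd_in_rev_rrule(f, xs::AbstractArray{<:Real})
    ys = map(f, xs)
    dydxs = map(x -> complex_step_derivative(f, x), xs)  # forward mode, per element
    # Pullback: elementwise product; returns (∂map, ∂f, ∂xs), f not differentiated.
    map_pullback(ȳs) = (nothing, nothing, ȳs .* dydxs)
    return ys, map_pullback
end
```

Here f is evaluated twice per element (once for ys, once on the complex input); a real forward-mode AD would produce the value and the derivative in a single pass.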

@dfdx

dfdx commented May 7, 2021

I played around with a similar idea for broadcasting in Yota:

function ∇broadcasted(dy, ::typeof(Broadcast.broadcasted), f::F, args...) where F
    df = get_deriv_function(Tuple{F, map(eltype, args)...})
    return NO_FIELDS, df.(dy, f, args...)
end

Here get_deriv_function() is a hardcoded version of ad_rrule, so basically we retrieve rrule* for the element types of args and broadcast it over all arguments. It kind of works, but since rrule returns a tuple, df.(dy, f, args...) returns an array of tuples rather than a tuple of arrays. For CPU arrays this may not be a big deal, but computing per-element derivatives and then combining them back on GPU will definitely destroy the performance.

Any idea how to deal with the array-of-tuples vs tuple-of-arrays issue?


*- strictly speaking, it's not rrule, but a function with a similar signature and semantics.
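On the CPU the layout issue can be seen, and worked around, with plain arrays: broadcasting a function that returns a tuple yields an array of tuples, and a tuple of arrays can be recovered with one `getindex` broadcast per tuple position. `per_element_grads` and `unzip_by_getindex` are hypothetical names; whether this conversion is cheap on a GPU is exactly the open question here.

```julia
# Hypothetical per-element derivative of x * y w.r.t. (f, x, y), for a seed dy.
per_element_grads(dy, x, y) = (nothing, dy * y, dy * x)

# Convert an array of N-tuples into an N-tuple of arrays (same shape as the input).
# N is read at runtime, so this is simple but not type-stable.
function unzip_by_getindex(tuples::AbstractArray{<:Tuple})
    N = length(first(tuples))
    return ntuple(i -> getindex.(tuples, i), N)
end
```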

@oxinabox
Member Author

oxinabox commented May 7, 2021

Good question.
Zygote does this stuff with StaticGetter and unzip.
I don't know how well optimized it is for GPU
(it is much better than naive unzipping via zip on CPU):
https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185

@dfdx

dfdx commented May 7, 2021

StaticGetter still hits getindex(), but the following seems to work:

julia> A = cu([(Zero(), 2, 3), (Zero(), 5, 6), (Zero(), 8, 9)])
3-element CuArray{Tuple{Zero, Int64, Int64}, 1}:
 (Zero(), 2, 3)
 (Zero(), 5, 6)
 (Zero(), 8, 9)

julia> map(x -> x[1], A)
3-element CuArray{Zero, 1}:
 Zero()
 Zero()
 Zero()

julia> map(x -> x[2], A)
3-element CuArray{Int64, 1}:
 2
 5
 8

julia> map(x -> x[3], A)
3-element CuArray{Int64, 1}:
 3
 6
 9

I'll try to do it for rrule() assuming ad_rrule is provided.

@dfdx

dfdx commented May 8, 2021

So far no luck with pullback-based systems:

julia> f = *
* (generic function with 379 methods)

julia> args = map(cu, (rand(2), rand(2)))
(Float32[0.37061554, 0.97347444], Float32[0.96509105, 0.7939103])

julia> rrule.(f, args...)
ERROR: GPU broadcast resulted in non-concrete element type Union{}.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] copy
   @ ~/.julia/packages/GPUArrays/bjw3g/src/host/broadcast.jl:44 [inlined]
 [3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(rrule), Tuple{Base.RefValue{typeof(*)}, CuArray{Float32, 1}, CuArray{Float32, 1}}})
   @ Base.Broadcast ./broadcast.jl:883
 [4] top-level scope
   @ REPL[3]:1

As far as I understand, CUDA.jl doesn't play well with any kind of closures, e.g. here's a simpler example without rrule:

julia> x = 1.0
1.0

julia> foo(y) = x + y
foo (generic function with 1 method)

julia> foo.(rand(2))
2-element Vector{Float64}:
 1.150095833562403
 1.1280587660314911

julia> foo.(cu(rand(2)))
ERROR: GPU broadcast resulted in non-concrete element type Any.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] copy
   @ ~/.julia/packages/GPUArrays/bjw3g/src/host/broadcast.jl:44 [inlined]
 [3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(foo), Tuple{CuArray{Float32, 1}}})
   @ Base.Broadcast ./broadcast.jl:883
 [4] top-level scope
   @ REPL[18]:1

It might be possible to rewrite rrule with something like Cassette to replace all calls to f(args...) with broadcast(f, args...) (or equivalent for other higher order functions), but it doesn't sound very robust.

@devmotion
Member

Does the second example work if x is a constant? I guess this should fix the type instability.
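The reasoning behind the suggestion: a non-`const` global can be rebound to a value of any type at any time, so a function that reads it cannot be inferred beyond `Any`, and GPU broadcasting needs a concrete element type. A quick CPU-only check with `Base.return_types` (the names here are illustrative):

```julia
# Non-constant global: its type can change, so code reading it infers as Any.
x_nonconst = 1.0
add_nonconst(y) = x_nonconst + y

# Constant global: the compiler can rely on its type (Float64).
const X_CONST = 1.0
add_const(y) = X_CONST + y

# Inferred return types for a Float64 argument.
nonconst_rt = Base.return_types(add_nonconst, (Float64,))
const_rt = Base.return_types(add_const, (Float64,))
```

Passing the value as an argument (or capturing it in a closure) avoids the global altogether.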

@dfdx

dfdx commented May 8, 2021

Indeed, making x a constant in global scope fixes the issue in the example. For rrule it still doesn't work though:

julia> const f = *
* (generic function with 402 methods)

julia> const args = map(cu, (rand(2), rand(2)))
(Float32[0.08670729, 0.5492601], Float32[0.24855424, 0.8392036])

julia> rrule.(f, args...)
ERROR: GPU broadcast resulted in non-concrete element type Union{}.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] copy
   @ ~/.julia/packages/GPUArrays/bjw3g/src/host/broadcast.jl:44 [inlined]
 [3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(rrule), Tuple{Base.RefValue{typeof(*)}, CuArray{Float32, 1}, CuArray{Float32, 1}}})
   @ Base.Broadcast ./broadcast.jl:883
 [4] top-level scope
   @ REPL[3]:1

julia> @code_warntype rrule.(f, args...)
Variables
  #self#::Core.Const(var"##dotfunction#489#95"())
  x1::Core.Const(*)
  x2::Tuple{CuArray{Float32, 1}, CuArray{Float32, 1}}

Body::Union{}
1 ─ %1 = Core.tuple(Main.rrule, x1)::Core.Const((rrule, *))
│   %2 = Core._apply_iterate(Base.iterate, Base.broadcasted, %1, x2)::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(rrule), Tuple{Base.RefValue{typeof(*)}, CuArray{Float32, 1}, CuArray{Float32, 1}}}
│        Base.materialize(%2)
└──      Core.Const(:(return %3))

@devmotion
Member

I just checked your CUDA example, and everything is inferred correctly for me (even without const). I used

(jl_OUkHOF) pkg> st
      Status `/tmp/jl_OUkHOF/Project.toml`
  [052768ef] CUDA v3.1.0
  [082447d4] ChainRules v0.7.63

julia> versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-6850K CPU @ 3.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, broadwell)
Environment:
  JULIA_NUM_THREADS = 12

@dfdx

dfdx commented May 13, 2021

Thanks for checking it! Indeed, updating to the latest versions of the packages fixed that specific example, and even more complex ones, e.g.:

using CUDA

foo(x::Number) = (2x, dy -> dy * x)   # a simple rrule-like function
bar(x, dy) = ((y, pb) = foo(x); pb(dy))   # apply the pullback to the seed dy
A = cu(rand(1024, 64))

bar.(A, 1f0)
# ok

But that's not really my concern; what I'm trying to check is how wrapping functions into closures affects performance on GPU. For example, when I write:

quux.(A)

and quux() is a plain old function, I'm pretty sure CUDA.jl will be able to generate an efficient kernel from it and apply this kernel to each element of A. However, in the example above, foo.(A) returns an array of tuples containing closures. A closure is a CPU object wrapping a scalar (a single element of the GPU array), and that doesn't sound very GPU-friendly. Even if benchmarks on simple examples show good performance, we should be really careful not to accidentally kill CUDA's optimizations on more complex examples.
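One way to probe this concern on the CPU: a Julia closure is an immutable struct whose fields are its captured variables, so a pullback that captures only a scalar is itself an `isbits` value, and so is a tuple of the primal output and such a pullback. That is what lets GPU arrays, which generally require bits-type elements, store them by value; whether the compiler then optimizes them as well as a plain function is a separate, benchmark-level question. `rrule_like` is an illustrative function in the style of `foo` above.

```julia
# rrule-like function: value and a pullback closure capturing x.
rrule_like(x::Number) = (x^2, dy -> 2x * dy)

y_k, pb_k = rrule_like(3.0f0)

# A closure is a struct whose fields are its captured variables (here just x).
closure_fields = fieldnames(typeof(pb_k))
closure_is_bits = isbitstype(typeof(pb_k))

# Broadcasting over a CPU array: element type is a concrete (value, closure) tuple.
results = rrule_like.(Float32[1, 2, 3])
results_are_bits = isbitstype(eltype(results))
```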


Some concrete experiments with rrule, perhaps not the most optimal:

using ChainRules
using CUDA
using BenchmarkTools

args = map(cu, (rand(1024, 64), rand(1024, 64)))
r = rrule.(*, args...)
pbs = map(x -> x[2], r)
@btime map((pb, x) -> pb(x), $pbs, 1f0)
# ==> 18.085 μs (41 allocations: 976 bytes)

plain_deriv(dy, x, y) = (Zero(), dy * y, dy * x)
@btime plain_deriv.(1f0, args...)
# ==> 4.968 μs (26 allocations: 592 bytes)

I must admit the pullback-based example also runs in 5 μs if I use a proper dy = cu(ones(1024, 64)) instead of 1f0, yet the behavior above is quite unintuitive to me.

But maybe I'm over-complicating things and in practice everything will work just fine.

@dfdx

dfdx commented May 20, 2021

Here's an implementation of rrule for broadcasted() which works with CPU and GPU arrays as long as rrule.(f, args...) works:

# from Zygote:
# https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]

@generated function _unzip(tuples, ::Val{N}) where {N}
  Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i in 1:N)...)
end

function unzip(tuples)
  N = length(first(tuples))
  _unzip(tuples, Val(N))
end


function rrule(::typeof(Broadcast.broadcasted), f::F, args...) where F
    ys, pbs = unzip(rrule.(f, args...))
    function pullback(Δ)
        dxs = map((pb, Δ) -> pb(Δ), pbs, Δ)
        return NO_FIELDS, unzip(dxs)...
    end
    return ys, pullback
end

I didn't notice any significant performance degradation compared to the non-closure-based version, but rrule.() fails on some examples, e.g.:

 rrule.(^, cu(rand(2)), 2f0)

Since rrule.() is a placeholder for ad_rrule() (or whatever we end up with) and ad_rrule() may behave differently, I just stopped here and haven't investigated the error.
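The broadcasted rule above can be exercised on plain CPU arrays, without ChainRules, by swapping in a toy scalar rule. `toy_rrule` and `toy_broadcast_rrule` are hypothetical stand-ins for `rrule` and the rule above, and `unzip_tuples` is a simpler (non-`@generated`, not type-stable) unzip; this only checks the data flow, not GPU behaviour or performance.

```julia
# Toy scalar rule for *: returns (primal, pullback), pullback gives (∂self, ∂x, ∂y).
toy_rrule(::typeof(*), x::Real, y::Real) = (x * y, Δ -> (nothing, Δ * y, Δ * x))

# Array of N-tuples -> N-tuple of arrays, via one map per tuple position.
unzip_tuples(ts) = ntuple(i -> map(t -> t[i], ts), length(first(ts)))

# Mirrors the broadcasted rrule above, with toy_rrule in place of rrule.
function toy_broadcast_rrule(f, args...)
    ys, pbs = unzip_tuples(toy_rrule.(f, args...))
    function broadcast_pullback(Δ)
        dxs = map((pb, δ) -> pb(δ), pbs, Δ)    # per-element pullbacks
        return (nothing, unzip_tuples(dxs)...)  # (∂broadcasted, ∂f, ∂args...)
    end
    return ys, broadcast_pullback
end
```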


4 participants