rrule_via_ad // frule_via_ad Calling back into AD from ChainRules #68
Comments
use ChainRulesCore
Solution that @willtebbutt and I came up with the other night.
struct Config{F, R}
ad_frule::F
ad_rrule::R
end
Where We might want to specify that they have to be actual functions so that we can actually dispatch on them being present.
function rrule(::typeof(map), config::Config{<:Any, <:Function}, f, xs...)
y = map(f, xs...)
function map_pullback(ȳ)
ntuple(length(xs)+2) do full_i
full_i == 1 && return NO_FIELDS
full_i == 2 && return DoesNotExist()
i = full_i-2
@thunk map(ȳ, xs...) do ȳi, xis...
_, pullback = config.ad_rrule(f, xis...)
∂xis = pullback(ȳi)
extern(∂xis[i+1]) #+1 to skip ∂self
end
end
end
return y, map_pullback
end
Which since we dispatched on
One worry is the case where one is in some non-AD scenario, but where one knows all the things should have rules.
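A minimal, package-free sketch of the dispatch idea above. It assumes hypothetical names (`Config`, `cfg_rrule`, `sin_rrule`) standing in for ChainRulesCore's `rrule` and for an AD's callback, uses `nothing` in place of `NO_FIELDS`/`DoesNotExist()`, and stores the per-element pullbacks instead of recomputing them inside a thunk. The point is only that the `map` rule is selected when `ad_rrule` is a `Function`, while a `Config` carrying `nothing` falls through to a "no rule" method.

```julia
# Hypothetical stand-in for the proposed Config (not the ChainRulesCore API).
struct Config{F, R}
    ad_frule::F
    ad_rrule::R
end

# Fallback: `nothing` means "no rule for this call".
cfg_rrule(f, config::Config, args...) = nothing

# A scalar rule playing the role of the AD callback: (y, dy -> (∂self, ∂x)).
sin_rrule(::typeof(sin), x::Real) = (sin(x), dy -> (nothing, dy * cos(x)))

# Only selected when the config actually carries a reverse-mode callback.
function cfg_rrule(::typeof(map), config::Config{<:Any, <:Function}, f, xs...)
    # One (y, pullback) pair per element, obtained from the AD callback.
    results = map((xis...) -> config.ad_rrule(f, xis...), xs...)
    y = map(first, results)
    pbs = map(last, results)
    function map_pullback(dy)
        # Each element pullback returns (∂f, ∂x1, ∂x2, ...).
        per_elem = map((pb, dyi) -> pb(dyi), pbs, dy)
        # (∂map, ∂f, ∂xs...); index i + 1 skips the per-element ∂f.
        return (nothing, nothing, ntuple(i -> map(t -> t[i + 1], per_elem), length(xs))...)
    end
    return y, map_pullback
end
```

Returning `nothing` for "no rule" mirrors the convention of ChainRulesCore's generic `rrule` fallback.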
@jrevels's old idea, which was in the source code for ages and which I have meant to transfer, is
Which could work. It would basically fix any use of I know for purposes of allowing 2nd derivatives, etc.
A slight variation on the
Firstly, add a couple of extra types:
abstract type AbstractAD{T} end
struct ForwardsModeAD{T} <: AbstractAD{T}
pushforward::T
end
struct ReverseModeAD{T} <: AbstractAD{T}
pullback::T
end
These simply wrap an AD, e.g.
Implement a fallback definition of
frule(::AbstractAD, tangents, values...) = frule(tangents, values...)
rrule(::AbstractAD, values...) = rrule(values...)
This gives rule authors two options.
This gives AD package authors various options:
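A sketch of the wrapper-type variation, again with hypothetical names (`my_rrule`, `fd_pullback`) rather than the real ChainRulesCore API. The "AD" wrapped in `ReverseModeAD` here is a central finite-difference stand-in returning `(y, dy -> (∂x,))`; that calling convention is an assumption made for the example, not a real AD package's interface.

```julia
# Hypothetical wrapper types from the comment above (not a real package API).
abstract type AbstractAD{T} end

struct ForwardsModeAD{T} <: AbstractAD{T}
    pushforward::T
end

struct ReverseModeAD{T} <: AbstractAD{T}
    pullback::T
end

# Plain rules: `nothing` means "no rule", plus one hand-written rule for sin.
my_rrule(f, args...) = nothing
my_rrule(::typeof(sin), x::Real) = (sin(x), dy -> (nothing, dy * cos(x)))

# Fallback: a rule that ignores the AD argument is just the plain rule.
my_rrule(::AbstractAD, f, args...) = my_rrule(f, args...)

# Opt-in: a `map` rule that calls back into the wrapped reverse-mode AD.
# Assumed convention: ad.pullback(f, x) returns (y, dy -> (∂x,)).
function my_rrule(ad::ReverseModeAD, ::typeof(map), f, xs::AbstractVector)
    results = map(x -> ad.pullback(f, x), xs)
    ys = map(first, results)
    function map_pullback(dys)
        ∂xs = map((r, dy) -> only(last(r)(dy)), results, dys)
        return (nothing, nothing, ∂xs)   # (∂map, ∂f, ∂xs)
    end
    return ys, map_pullback
end

# Finite-difference stand-in for a real AD's pullback (illustration only).
fd_pullback(f, x::Real; h = 1e-6) = (f(x), dy -> (dy * (f(x + h) - f(x - h)) / (2h),))
```

Rule authors who never look at the first argument get the plain rule through the fallback method, while the `map` rule opts in by dispatching on `ReverseModeAD`.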
I played around with a similar idea for broadcasting in Yota:
function ∇broadcasted(dy, ::typeof(Broadcast.broadcasted), f::F, args...) where F
df = get_deriv_function(Tuple{F, map(eltype, args)...})
return NO_FIELDS, df.(dy, f, args...)
end
Here Any idea how to deal with the array-of-tuples vs tuple-of-arrays issue?
Good question.
julia> A = cu([(Zero(), 2, 3), (Zero(), 5, 6), (Zero(), 8, 9)])
3-element CuArray{Tuple{Zero, Int64, Int64}, 1}:
(Zero(), 2, 3)
(Zero(), 5, 6)
(Zero(), 8, 9)
julia> map(x -> x[1], A)
3-element CuArray{Zero, 1}:
Zero()
Zero()
Zero()
julia> map(x -> x[2], A)
3-element CuArray{Int64, 1}:
2
5
8
julia> map(x -> x[3], A)
3-element CuArray{Int64, 1}:
3
6
9
I'll try to do it for
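The same per-position `map` works on an ordinary CPU `Array`. A small sketch: the helper name `unzip_tuples` is hypothetical, `nothing` stands in for `Zero()`, and it assumes a concrete tuple element type so that `fieldcount` can tell how many positions there are.

```julia
# Hypothetical helper: split an array of N-tuples into an N-tuple of arrays,
# one `map` per tuple position (the CPU analogue of `map(x -> x[i], A)`).
function unzip_tuples(A::AbstractArray{T}) where {T<:Tuple}
    N = fieldcount(T)   # assumes a concrete tuple type such as Tuple{Nothing, Int, Int}
    return ntuple(i -> map(t -> t[i], A), N)
end
```

For example, `unzip_tuples([(nothing, 2, 3), (nothing, 5, 6)])` returns `([nothing, nothing], [2, 5], [3, 6])`.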
So far no luck with pullback-based systems:
julia> f = *
* (generic function with 379 methods)
julia> args = map(cu, (rand(2), rand(2)))
(Float32[0.37061554, 0.97347444], Float32[0.96509105, 0.7939103])
julia> rrule.(f, args...)
ERROR: GPU broadcast resulted in non-concrete element type Union{}.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] copy
@ ~/.julia/packages/GPUArrays/bjw3g/src/host/broadcast.jl:44 [inlined]
[3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(rrule), Tuple{Base.RefValue{typeof(*)}, CuArray{Float32, 1}, CuArray{Float32, 1}}})
@ Base.Broadcast ./broadcast.jl:883
[4] top-level scope
@ REPL[3]:1
As far as I understand, CUDA.jl doesn't play well with any kind of closure, e.g. here's a simpler example without
julia> x = 1.0
1.0
julia> foo(y) = x + y
foo (generic function with 1 method)
julia> foo.(rand(2))
2-element Vector{Float64}:
1.150095833562403
1.1280587660314911
julia> foo.(cu(rand(2)))
ERROR: GPU broadcast resulted in non-concrete element type Any.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] copy
@ ~/.julia/packages/GPUArrays/bjw3g/src/host/broadcast.jl:44 [inlined]
[3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(foo), Tuple{CuArray{Float32, 1}}})
@ Base.Broadcast ./broadcast.jl:883
[4] top-level scope
@ REPL[18]:1
It might be possible to rewrite
Does the second example work if
Indeed, making
julia> const f = *
* (generic function with 402 methods)
julia> const args = map(cu, (rand(2), rand(2)))
(Float32[0.08670729, 0.5492601], Float32[0.24855424, 0.8392036])
julia> rrule.(f, args...)
ERROR: GPU broadcast resulted in non-concrete element type Union{}.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] copy
@ ~/.julia/packages/GPUArrays/bjw3g/src/host/broadcast.jl:44 [inlined]
[3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(rrule), Tuple{Base.RefValue{typeof(*)}, CuArray{Float32, 1}, CuArray{Float32, 1}}})
@ Base.Broadcast ./broadcast.jl:883
[4] top-level scope
@ REPL[3]:1
julia> @code_warntype rrule.(f, args...)
Variables
#self#::Core.Const(var"##dotfunction#489#95"())
x1::Core.Const(*)
x2::Tuple{CuArray{Float32, 1}, CuArray{Float32, 1}}
Body::Union{}
1 ─ %1 = Core.tuple(Main.rrule, x1)::Core.Const((rrule, *))
│ %2 = Core._apply_iterate(Base.iterate, Base.broadcasted, %1, x2)::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(rrule), Tuple{Base.RefValue{typeof(*)}, CuArray{Float32, 1}, CuArray{Float32, 1}}}
│ Base.materialize(%2)
└── Core.Const(:(return %3))
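For the simpler `foo` example, the likely culprit is the non-const global `x`: the GPU broadcast picks its output element type from type inference rather than from the computed values, and inference cannot assume anything about a non-const global because it can be rebound to a value of any type. A CPU-only way to see the difference (all names here are illustrative):

```julia
# Closing over a non-const global: it can be rebound to a value of any type,
# so inference has to treat it as `Any`.
x_global = 1.0
add_global(y) = x_global + y

# The same function over a `const` global has a fixed, concrete type.
const x_const = 1.0
add_const(y) = x_const + y

# Inferred return types for a Float64 argument.
rt_global = only(Base.return_types(add_global, (Float64,)))
rt_const = only(Base.return_types(add_const, (Float64,)))
println("non-const global: ", rt_global, "; const global: ", rt_const)
```

This only explains the `Any` case; the `Union{}` result for `rrule.(f, args...)` is a separate symptom.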
I just checked your CUDA example, and everything is inferred correctly for me (even without
(jl_OUkHOF) pkg> st
Status `/tmp/jl_OUkHOF/Project.toml`
[052768ef] CUDA v3.1.0
[082447d4] ChainRules v0.7.63
julia> versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: Intel(R) Core(TM) i7-6850K CPU @ 3.60GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-11.0.1 (ORCJIT, broadwell)
Environment:
JULIA_NUM_THREADS = 12
Thanks for checking it! Indeed, updating to the latest versions of the packages fixed that specific example,
using CUDA
foo(x::Number) = (2x, dy -> dy * x) # a simple rrule-like function
bar(x, dy) = ((y, pb) = foo(x); pb(dy))  # apply the pullback to the cotangent dy
A = cu(rand(1024, 64))
bar.(A, 1f0)
# ok
But that's not really my concern; what I'm trying to check is how wrapping functions into closures affects performance on the GPU. For example, when I write:
quux.(A)
and
Some concrete experiments with
using ChainRules
using CUDA
using BenchmarkTools
args = map(cu, (rand(1024, 64), rand(1024, 64)))
r = rrule.(*, args...)
pbs = map(x -> x[2], r)
@btime map((pb, x) -> pb(x), $pbs, 1f0)
# ==> 18.085 μs (41 allocations: 976 bytes)
plain_deriv(dy, x, y) = (Zero(), dy * y, dy * x)
@btime plain_deriv.(1f0, args...)
# ==> 4.968 μs (26 allocations: 592 bytes)
I must admit the pullback-based examples also run in 5 μs if I use a proper
But maybe I'm over-complicating things and in practice everything will work just fine.
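One reason closure-based pullbacks can be stored in GPU arrays at all: in Julia each anonymous function is an instance of its own compiler-generated immutable struct whose fields are the captured variables, so a closure over a `Float32` is a concrete `isbits` type. A CPU-only check (`make_pullback` is a hypothetical helper); it says nothing about kernel performance, only about the element type:

```julia
# Each anonymous function is an instance of a compiler-generated immutable
# struct; the captured variable `x` becomes a field of that struct.
make_pullback(x) = dy -> dy * x

pbs = map(make_pullback, Float32[1, 2, 3])
T = eltype(pbs)

println("concrete: ", isconcretetype(T), ", isbits: ", isbitstype(T),
        ", fields: ", fieldnames(T))

# Applying the stored pullbacks elementwise, as in the benchmark above.
grads = map((pb, dy) -> pb(dy), pbs, fill(2.0f0, 3))
```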
Here's an implementation of
# from Zygote:
# https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
@generated function _unzip(tuples, ::Val{N}) where {N}
Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i ∈ 1:N)...)
end
function unzip(tuples)
N = length(first(tuples))
_unzip(tuples, Val(N))
end
function rrule(::typeof(Broadcast.broadcasted), f::F, args...) where F
ys, pbs = unzip(rrule.(f, args...))
function pullback(Δ)
dxs = map((pb, Δ) -> pb(Δ), pbs, Δ)
return NO_FIELDS, unzip(dxs)...
end
return ys, pullback
end
I didn't notice any significant performance degradation compared to the non-closure-based version, but
rrule.(^, cu(rand(2)), 2f0)
Since
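A CPU-only sketch of the same broadcast-rrule shape, with a hypothetical `scalar_rrule` in place of ChainRules' `rrule` and `first.`/`last.` in place of the generated `unzip` (which is enough here because each result is a 2-tuple). Functions passed to a broadcast are treated as scalars, which is why `f` is not iterated over.

```julia
# Hypothetical scalar rule standing in for ChainRules' `rrule(sin, x)`:
# returns the primal and a pullback giving (∂self, ∂x).
scalar_rrule(::typeof(sin), x::Real) = (sin(x), dy -> (nothing, dy * cos(x)))

# Broadcast the scalar rule, then split the array of (y, pullback) pairs.
function broadcast_rrule(f, xs::AbstractArray)
    results = scalar_rrule.(f, xs)      # array of (y, pullback) tuples
    ys = first.(results)
    pbs = last.(results)
    function broadcast_pullback(Δ)
        ∂s = map((pb, Δi) -> pb(Δi), pbs, Δ)   # array of (∂self, ∂x) tuples
        return (nothing, nothing, last.(∂s))    # (∂broadcasted, ∂f, ∂xs)
    end
    return ys, broadcast_pullback
end
```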
This was originally discussed in JuliaDiff/ChainRules.jl#12 (comment) and in a few other places.
Basically, when defining a chainrule (frule or rrule), it would often be nice to be able to say "Give me this chainrule for some of the function, and if there is not one predefined, use AD to get it" as part of your rule definition.
Right now this is not possible, except by hard-coding an AD (e.g. Zygote) in.
Whereas if we had a function that an AD system could overload, then we could do that.
It would also provide a common API for all ADs that support it.
This would help, e.g., with higher-order functions like map, broadcast, etc. (JuliaDiff/ChainRules.jl#122).
There is some fiddliness involved in making sure it can be overloaded multiple times, that the user can choose which AD to use, and that it all compiles away, but I think we can sort it all out.
@jessebett reminded me of this today