Add derivatives_given_input #456
Conversation
Codecov Report

```diff
@@            Coverage Diff             @@
##           master     #456      +/-   ##
==========================================
- Coverage   92.94%   92.89%   -0.06%
==========================================
  Files          14       14
  Lines        794      802       +8
==========================================
+ Hits         738      745       +7
- Misses        56       57       +1
```

Continue to review full report at Codecov.
```diff
-return @strip_linenos quote
-    function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...))
+given_output = @strip_linenos quote
+    @inline function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...))
```
Inlining here is probably a good idea regardless of whether one wants `derivatives_given_input`. Maybe it can be added in a separate PR?
I forgot about that. Will fiddle a bit more to see if these matter.
```julia
given_input = @strip_linenos quote
    @inline function ChainRulesCore.derivatives_given_input(::Core.Typeof($f), $(inputs...))
        $(__source__)
        $(setup_stmts...)
        return $(Expr(:tuple, partials...))
    end
end
```
In some cases it might also be useful to know if the input is no longer necessary given the outputs. Maybe a more general solution could be to have something like a trait `needed_for_derivatives(Ω, f, inputs...)` returning a pair of booleans (whether the input is used and whether the output is used). Then things like `sum` could check exactly what variables `derivatives_given_output` requires, and fill everything else with `nothing`.
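For concreteness, a sketch of what such a trait might look like; the name, the `(input_used, output_used)` order, and the fallback are all illustrative, nothing here exists yet:

```julia
# Hypothetical trait: does the scalar rule for `f` use the input, and does it use
# the output?  Returned as (input_used, output_used); the conservative fallback is both.
needed_for_derivatives(Ω, f, inputs...) = (true, true)

# log's derivative is 1/x, so only the input is needed;
# sqrt's derivative is 1/(2Ω), so only the output is needed.
needed_for_derivatives(Ω, ::typeof(log), x)  = (true, false)
needed_for_derivatives(Ω, ::typeof(sqrt), x) = (false, true)
```

A caller like the `sum` rule could then consult this pair and pass `nothing` for (or avoid storing) whichever of the two is reported as unused.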
Yes.

Re not needing the input, my example here is FluxML/NNlib.jl#346, where I want `activate!(f, x) = f.(x)` which is free to overwrite `x`. It has, right now, a hard-coded list of functions. Are there other uses for this?

Re how to tell, agree we could have some indicator beyond checking for methods. In JuliaDiff/ChainRules.jl#529 it seems convenient to do this based on `f, typeof(x)`. I think, but have not checked, that using `eltype` not `first` will avoid problems with CuArrays. Although there are other ways.
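As a sketch of the "based on `f, typeof(x)`" idea (the helper name is made up and this is untested; the point is only that `eltype` avoids touching the array's elements):

```julia
using ChainRulesCore

# Is there a scalar-rule fast path for broadcasting `f` over `x`?  Judged from the
# element type alone, so a CuArray is never scalar-indexed via first(x).
function has_scalar_rule(f, x::AbstractArray)
    T = eltype(x)   # eltype, not typeof(first(x))
    return hasmethod(ChainRulesCore.derivatives_given_output, Tuple{T, typeof(f), T})
end
```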
For some functions you have a choice about whether to write the rule with `Ω` or `x` or both. I don't know how far down that path we want to go. When is re-computing `Ω` more efficient than saving it, do we need a cost model? Right now it just looks for the presence of the symbol `:Ω`.
In this:

> `sum` could check exactly what variables `derivatives_given_output` requires, and fill everything else with `nothing`.

are you thinking about something like `sum(sincos, xs)`? I.e. a function with multiple returns? No, that won't work, as you can't add tuples; it would have to be `mapreduce(sincos, tadd, xs)`. My inclination is to say this is too obscure, multiple-return functions are pretty rare (and IMO would ideally not be handled by `@scalar_rule`, there are so few that doing them by hand might even save lines of code).
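(To make the tuple point concrete: a hypothetical `tadd` would just be elementwise tuple addition.)

```julia
# Base has no `+` method for tuples, so sum(sincos, xs) fails;
# an elementwise combiner makes the reduction work:
tadd(a::Tuple, b::Tuple) = map(+, a, b)

mapreduce(sincos, tadd, [1.0, 2.0, 3.0])  # == (sum of sines, sum of cosines)
```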
For `mapreduce(f, +, x, y, z)` I guess knowing that the gradient doesn't need (say) `z` could let you avoid closing over that. But also seems a little obscure.
Sorry, I worded that incorrectly. I meant to say that `needed_for_derivatives(Ω, f, inputs...)` would mention whether the output overall is used and whether the input overall is used (not differentiating between different components of the output or different inputs). The main use case for "does not need inputs" is indeed fused linear + pointwise nonlinearity kernels (eg `Dense` or `Conv`). It is only one use case, but it's the key component of most neural networks, so it probably deserves this performance optimization.

I was not completely sure whether it would make sense to distinguish whether a `@scalar_rule` needs all the inputs or only some of them, but I am also coming to the conclusion that differentiating between those would be too obscure.

From a practical perspective, I wonder how easy it is, by looking at the expression in `@scalar_rule`, to tell that the inputs are not used. There may be some annoyances due to destructuring, but it should be possible to figure out whether some of the inputs are used or not.
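A rough sketch of that expression check (ignoring the destructuring complications mentioned above; names are illustrative):

```julia
# Does the symbol `s` (e.g. :x or :Ω) appear anywhere in the rule expression `ex`?
uses_symbol(s::Symbol, ex::Expr) = any(arg -> uses_symbol(s, arg), ex.args)
uses_symbol(s::Symbol, ex) = ex === s   # leaves: symbols, literals, etc.

uses_symbol(:Ω, :(isnothing(Ω) ? x > 0 : Ω > 0))  # true
uses_symbol(:Ω, :(inv(x)))                        # false
```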
Ok.
One advantage of just checking whether methods exist, instead of having a special indicator function, is that it's harder to lie. We will probably want some methods not defined through the macro.
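For instance, a hand-written method might look roughly like this; the function choice is just illustrative, and the `nothing` placeholder anticipates the discussion further down:

```julia
using ChainRulesCore

relu(x) = ifelse(x > 0, x, zero(x))   # as in the REPL example below

# Hand-written rather than generated by @scalar_rule: an input-only path,
# dispatching on `nothing` in the output slot.  Same return shape as the macro:
# one tuple of partials per output.
function ChainRulesCore.derivatives_given_output(::Nothing, ::typeof(relu), x::Real)
    return ((x > 0,),)
end
```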
One more issue with the business of the macro looking for Omega (or `x`) in the expressions is that sometimes there is a choice: e.g. for `relu` you can use `Ω>0` or `x>0`, and the automatic rule will be efficient in different operations. Although maybe such functions are rare and we can just occasionally add a method by hand.
If you do have an indicator `needed_for_derivatives`, then you could just have one function and call `derivatives_given_stuff(nothing, f, x)` or `(y, f, nothing)`. If you don't, then there need to be 3 functions: `derivatives_given_input`, `derivatives_given_output`, `derivatives_given_both`. (Although I'd vote for shorter names... can these be variants of `srule`, s for scalar? `srule_in`, `srule_out`, `srule_out_in`, or something.)
In fact, even with just one function, if the way you check whether a method exists is type inference, then it will already tell you whether the result is used:

```julia
julia> Core.Compiler._return_type(derivatives_given_output, Tuple{Nothing, typeof(log), Float64})
Tuple{Tuple{Float64}}

julia> Core.Compiler._return_type(derivatives_given_output, Tuple{Nothing, typeof(sqrt), Float64})
Union{}
```
JuliaDiff/ChainRules.jl#529 now does this, instead of using this PR. Won't help with the two-rules-for-`relu` story, though. You can define e.g. `derivatives_given_output(::Nothing, ::typeof(relu), x::Real)` to target one use or the other, if everyone uses `nothing` as the marker.
> One more issue with the business of the macro looking for Omega (or `x`) in the expressions is that sometimes there is a choice: e.g. for `relu` you can use `Ω>0` or `x>0`, and the automatic rule will be efficient in different operations.
That's a good point. So to summarize, the decision points are the following:

1. How to signal what methods exist among `srule_in`, `srule_out`, `srule_out_in`?
2. How to define those methods?

For 1., maybe one could check whether methods `derivatives_given_output(output::Nothing, f, inputs...)` and `derivatives_given_output(output, f, inputs::Nothing)` exist. Or maybe `derivatives_given_output(output, f, inputs::Nothing...)`? I'm not sure; maybe this would be cleaner if the main method was `derivatives_given_output(output::Tuple, f, inputs::Tuple)`. Otherwise, actually having functions `srule_in`, `srule_out`, `srule_out_in` also seems reasonable.
Question 2. is less clear. The most manual solution is to add to `@scalar_rule` keywords `is_output_nullable` and `is_input_nullable`. If those are supported, the rule will generate methods that allow `Nothing` as inputs and outputs: it is up to the user to make sure that is handled correctly (which must be tested, of course). If e.g. `Ω` does not appear in the formula, this will happen out of the box. For `relu`, one would write a rule that reads `isnothing(Ω) ? x > 0 : Ω > 0`.
What I am leaning towards is just keeping the existing `derivatives_given_output`, and checking whether it uses the input only (something like `_uses_input_only`), like this:
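A rough, untested sketch of what that check could be, building on `Core.Compiler._return_type` as in the REPL snippets in this thread (the helper name and the exact return-type test are just illustrative):

```julia
# Sketch: does f's scalar rule still infer when the output is replaced by `nothing`,
# i.e. can the partials be computed from the input alone?  Assumes the rules for
# log and sqrt (from ChainRules.jl) are loaded, as in the REPL examples here.
function _uses_input_only(f, xT::Type)
    R = Core.Compiler._return_type(derivatives_given_output,
                                   Tuple{Nothing, typeof(f), xT})
    return R !== Union{} && R <: Tuple   # Union{} means the rule actually needed Ω
end

_uses_input_only(log, Float64)   # true  -- the derivative 1/x needs only x
_uses_input_only(sqrt, Float64)  # false -- the derivative 1/(2Ω) needs Ω
```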
I guess the approach depends on being willing to call `Core.Compiler._return_type`. It seems OK to me to do this as an optimisation -- a type-unstable `derivatives_given_output` will just cause you to go down the slow path, but really these should be simple functions.

(Perhaps it should be a little more picky: are there cases where a rule actually returns the `nothing` instead of multiplying it by a number? Multiplying is an error, hence gives `Union{}`. Using `nothing` for this is just the first idea I had; there might be a smarter choice, or a reason to own the type used.)

For the `relu` story, just writing `@scalar_rule relu(x::Real) isnothing(Ω) ? x > 0 : Ω > 0` will, I think, ensure you get both fast paths. Maybe that's confusing to read... but my vote re designing keyword options for the macro is to wait until we have enough cases that it's obviously worthwhile.
Oh, here's why `nothing` isn't a great flag. Maybe we want a special struct for this:

```julia
julia> relu(x) = ifelse(x>0, x, zero(x));

julia> @scalar_rule relu(x::Real) isnothing(Ω) ? x > 0 : Ω > 0

julia> derivatives_given_output(nothing, relu, 1.0)
((true,),)

julia> derivatives_given_output(1.0, relu, nothing)
ERROR: MethodError: no method matching derivatives_given_output(::Float64, ::typeof(relu), ::Nothing)
Closest candidates are:
  derivatives_given_output(::Any, ::typeof(relu), ::Real) at REPL[20]:1

julia> struct Bomb <: Real end

julia> derivatives_given_output(1.0, relu, Bomb())
((true,),)

julia> Core.Compiler._return_type(derivatives_given_output, Tuple{Bomb, typeof(sqrt), Float64})
Union{}

julia> Core.Compiler._return_type(derivatives_given_output, Tuple{Float64, typeof(sqrt), Bomb})
Tuple{Tuple{Float64}}
```
This is a variant of #453, which gives gradients for `Ω = f(x)` without computing `f`, but also without knowing `Ω`. The function is not defined if the rule in fact uses `Ω`.

The reason to do this is to speed up `sum(f, x)`, JuliaDiff/ChainRules.jl#529. Not sure it's worth the complication, but see what you think.

Edit -- see discussion below, I currently think this (and the corresponding rule for when `x` is not needed, only `Ω`) can just be done as methods of `derivatives_given_output`.