
Add derivatives_given_input #456

Closed (wants to merge 1 commit)

Conversation

@mcabbott (Member) commented Sep 12, 2021

This is a variant of #453, which gives the gradients for Ω = f(x) without computing f; here they are also given without knowing Ω. The function is not defined if the rule in fact uses Ω.

The reason to do this is to speed up sum(f,x), JuliaDiff/ChainRules.jl#529. Not sure it's worth the complication, but see what you think.

Edit -- see discussion below, I currently think this (and the corresponding rule for when x is not needed, only Ω) can just be done as methods of derivatives_given_output.
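For orientation, here is a minimal sketch of the two kinds of methods in question, for a rule like log whose derivative needs only the input (an illustration, not this PR's actual diff):

```julia
using ChainRulesCore

# Existing API (from #453), roughly what @scalar_rule already generates for log:
# the partials given the output Ω and the input x. The derivative 1/x never uses Ω,
# so Ω is simply ignored here.
ChainRulesCore.derivatives_given_output(Ω, ::typeof(log), x) = ((inv(x),),)

# What this PR proposes (hypothetical signature): the same partials without supplying Ω at all.
derivatives_given_input(::typeof(log), x) = ((inv(x),),)
```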

@codecov-commenter commented Sep 12, 2021

Codecov Report

Merging #456 (387f4e2) into master (47389f5) will decrease coverage by 0.05%.
The diff coverage is 88.88%.


@@            Coverage Diff             @@
##           master     #456      +/-   ##
==========================================
- Coverage   92.94%   92.89%   -0.06%     
==========================================
  Files          14       14              
  Lines         794      802       +8     
==========================================
+ Hits          738      745       +7     
- Misses         56       57       +1     
Impacted Files | Coverage Δ
src/rule_definition_tools.jl | 95.73% <88.88%> (-0.43%) ⬇️


-    return @strip_linenos quote
-        function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...))
+    given_output = @strip_linenos quote
+        @inline function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...))
Contributor:

Inlining here is probably a good idea regardless of whether one wants derivatives_given_input. Maybe it can be added in a separate PR?

@mcabbott (Member Author):

I forgot about that. Will fiddle a bit more to see if these matter.

Comment on lines +195 to +201
given_input = @strip_linenos quote
    @inline function ChainRulesCore.derivatives_given_input(::Core.Typeof($f), $(inputs...))
        $(__source__)
        $(setup_stmts...)
        return $(Expr(:tuple, partials...))
    end
end
Contributor:

In some cases it might also be useful to know whether the input is no longer necessary given the output. Maybe a more general solution could be to have something like a trait needed_for_derivatives(Ω, f, inputs...) returning a pair of booleans (whether the input is used and whether the output is used).

Then things like sum could check exactly what variables derivatives_given_output requires, and fill everything else with nothing.
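For concreteness, such a trait might look like this (purely hypothetical, nothing like it exists in ChainRulesCore):

```julia
# Hypothetical trait: returns (input_needed, output_needed) for the derivatives of f.
needed_for_derivatives(Ω, ::typeof(log), x)  = (true, false)   # d/dx log(x) = 1/x uses only x
needed_for_derivatives(Ω, ::typeof(sqrt), x) = (false, true)   # d/dx sqrt(x) = 1/(2Ω) uses only Ω
```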

@mcabbott (Member Author):

Yes.

Re not needing the input, my example is FluxML/NNlib.jl#346, where I want activate!(f, x) = f.(x) to be free to overwrite x. Right now it has a hard-coded list of functions. Are there other uses for this?
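Roughly this (a sketch, not the NNlib PR's actual code):

```julia
# Overwriting x with Ω = f.(x) is safe exactly when the rule for f never needs
# the original x, i.e. its derivative can be written in terms of Ω alone.
function activate!(f, x::AbstractArray)
    x .= f.(x)
    return x
end
```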

Re how to tell, I agree we could have some indicator beyond checking for methods. In JuliaDiff/ChainRules.jl#529 it seems convenient to do this based on f and typeof(x). I think, but have not checked, that using eltype rather than first will avoid problems with CuArrays. Although there are other ways.

For some functions you have a choice about whether to write the rule with Ω or x or both. I don't know how far down that path we want to go. When is re-computing Ω more efficient than saving it? Do we need a cost model? Right now it just looks for the presence of the symbol.

@mcabbott (Member Author) commented Sep 12, 2021:

In this:

> sum could check exactly what variables derivatives_given_output requires, and fill everything else with nothing.

are you thinking about something like sum(sincos, xs), i.e. a function with multiple returns? No, that won't work, as you can't add tuples; it would have to be mapreduce(sincos, tadd, xs). My inclination is to say this is too obscure: multiple-return functions are pretty rare (and IMO would ideally not be handled by @scalar_rule; there are so few that doing them by hand might even save lines of code).
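For example (illustrative only; tadd here is just the made-up name from the sentence above):

```julia
# sum(sincos, xs) fails because Base defines no +(::Tuple, ::Tuple);
# you would need an explicit tuple-wise reduction instead.
tadd(a::Tuple, b::Tuple) = map(+, a, b)

mapreduce(sincos, tadd, [0.1, 0.2, 0.3])   # == (sum of the sines, sum of the cosines)
```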

For mapreduce(f, +, x, y, z) I guess knowing that the gradient doesn't need (say) z could let you avoid closing over that. But also seems a little obscure.

Contributor:

Sorry, I worded that incorrectly. I meant to say that needed_for_derivatives(Ω, f, inputs...) would mention whether the output overall is used and whether the input overall is used (not differentiating between different components of the output or different inputs). The main use case for "does not need inputs" is indeed fused linear + pointwise nonlinearity kernels (eg Dense or Conv). It is only one use case, but it's the key component of most neural networks, so it probably deserves this performance optimization.
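For example (a sketch of the use case, not NNlib's implementation):

```julia
# If σ's derivative can be written in terms of Ω alone, the pre-activation W*a .+ b
# never needs to be stored for the backward pass, so it can stay inside one fused kernel.
fused_dense(σ, W, a, b) = σ.(W * a .+ b)
```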

I was not completely sure whether it would make sense to distinguish whether a @scalar_rule needs all the inputs or only some of them, but I am also coming to the conclusion that differentiating between those would be too obscure.

From a practical perspective, I wonder how easy it is, by looking at the expression in @scalar_rule, to tell that the inputs are not used. There may be some annoyances due to destructuring, but it should be possible to figure out whether some of the inputs are used or not.

@mcabbott (Member Author) commented Sep 12, 2021:

Ok.

One advantage of just checking whether methods exist, instead of having a special indicator function, is that it's harder to lie. We will probably want some methods not defined through the macro, such as:

https://github.com/JuliaDiff/ChainRules.jl/pull/529/files#diff-a5d2f9e38b98bb95c73ac64e97966ab6eaa897bd7ee10e1ea73ffb550d9b4d5cR83-R84

One more issue with the business of the macro looking for Ω (or x) in the expressions is that sometimes there is a choice: e.g. for relu you can use Ω > 0 or x > 0, and the automatic rule will be efficient in different operations. Although maybe such functions are rare and we can just occasionally add a method by hand.

If you do have an indicator needed_for_derivatives, then you could just have one function and call derivatives_given_stuff(nothing, f, x) or (y, f, nothing). If you don't, then there need to be 3 functions, derivatives_given_input, derivatives_given_output, derivatives_given_both. (Although I'd vote for shorter names... can these be variants of srule, s for scalar? srule_in, srule_out, srule_out_in, or something.)

@mcabbott (Member Author) commented Sep 12, 2021:

In fact, even with just one function, if the way you check whether a method exists is type inference, then it will already tell you whether the result is used:

julia> Core.Compiler._return_type(derivatives_given_output, Tuple{Nothing, typeof(log), Float64})
Tuple{Tuple{Float64}}

julia> Core.Compiler._return_type(derivatives_given_output, Tuple{Nothing, typeof(sqrt), Float64})
Union{}

JuliaDiff/ChainRules.jl#529 now does this, instead of using this PR.

Won't help with the two-rules-for-relu story, though. You can define e.g. derivatives_given_output(::Nothing, ::typeof(relu), x::Real) to target one use or the other, if everyone uses nothing as the marker.
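Spelled out, that would be something like this (hand-written methods, purely illustrative):

```julia
using ChainRulesCore

relu(x::Real) = ifelse(x > 0, x, zero(x))   # stand-in definition, as in the REPL example below

# `nothing` as the "value not available" marker, one method per fast path:
ChainRulesCore.derivatives_given_output(::Nothing, ::typeof(relu), x::Real) = ((x > 0,),)  # Ω not computed yet
ChainRulesCore.derivatives_given_output(Ω::Real, ::typeof(relu), ::Nothing) = ((Ω > 0,),)  # x already discarded
```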

Contributor:

> One more issue with the business of the macro looking for Ω (or x) in the expressions is that sometimes there is a choice, e.g. relu you can use Ω>0 or x>0 & the automatic rule will be efficient in different operations.

That's a good point. So, to summarize, the decision points are the following.

  1. How to signal what methods exist among srule_in, srule_out, srule_out_in?
  2. How to define those methods?

For 1., maybe one could check whether methods derivatives_given_output(output::Nothing, f, inputs...) and derivatives_given_output(output, f, inputs::Nothing) exist. Or maybe derivatives_given_output(output, f, inputs::Nothing...)? I'm not sure, maybe this would be cleaner if the main method was derivatives_given_output(output::Tuple, f, inputs::Tuple). Otherwise, actually having functions srule_in, srule_out, srule_out_in also seems reasonable.

Question 2. is less clear. The most manual solution is to add to @scalar_rule keywords is_output_nullable and is_input_nullable. If those are supported, the macro will generate methods that allow nothing as inputs and outputs; it is up to the user to make sure that is handled correctly (and tested, of course). If e.g. Ω does not appear in the formula, this will happen out of the box. For relu, one would write a rule that reads isnothing(Ω) ? x > 0 : Ω > 0.

@mcabbott (Member Author):

What I am leaning towards is just keeping the existing derivatives_given_output, and checking whether it _uses_input_only like this:

https://github.com/JuliaDiff/ChainRules.jl/pull/529/files#diff-9e84213ce96cf861f2b18b2ee71e1dcc4bc9969020802f99705154e13204a1d3R109-R112

I guess the approach depends on being willing to call Core.Compiler._return_type. It seems OK to me to do this as an optimisation -- a type-unstable derivatives_given_output will just send you down the slow path, but really these should be simple functions.
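In outline, the check is along these lines (a paraphrase of the idea, not the exact linked code; names and details not guaranteed):

```julia
using ChainRulesCore: derivatives_given_output

# Does the rule for f, at element type xT, need only the input? Ask inference what
# derivatives_given_output(nothing, f, x) would return: Union{} means the call would
# error with Ω === nothing, i.e. the rule really needs the output, so take the slow path.
function _uses_input_only(f, xT::Type)
    gradT = Core.Compiler._return_type(derivatives_given_output, Tuple{Nothing, typeof(f), xT})
    return isconcretetype(gradT) && gradT <: Tuple
end
```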

(Perhaps it should be a little more picky: are there cases where a rule actually returns the nothing instead of multiplying it by a number? Multiplying is an error, hence gives Union{}. Using nothing for this is just the first idea I had; there might be a smarter choice, or a reason to own the type used.)

For the relu story, just writing @scalar_rule relu(x::Real) isnothing(Ω) ? x > 0 : Ω > 0 will I think ensure you get both fast paths. Maybe that's confusing to read... but my vote re designing keyword options for the macro is to wait until we have enough cases that it's obviously worthwhile.

@mcabbott (Member Author):

Oh, here's why nothing isn't a great flag. Maybe we want a special struct for this:

julia> relu(x) = ifelse(x>0, x, zero(x));

julia> @scalar_rule relu(x::Real) isnothing(Ω) ? x > 0 : Ω > 0

julia> derivatives_given_output(nothing, relu, 1.0)
((true,),)

julia> derivatives_given_output(1.0, relu, nothing)
ERROR: MethodError: no method matching derivatives_given_output(::Float64, ::typeof(relu), ::Nothing)
Closest candidates are:
  derivatives_given_output(::Any, ::typeof(relu), ::Real) at REPL[20]:1

julia> struct Bomb <: Real end

julia> derivatives_given_output(1.0, relu, Bomb())
((true,),)

julia> Core.Compiler._return_type(derivatives_given_output, Tuple{Bomb, typeof(sqrt), Float64})
Union{}

julia> Core.Compiler._return_type(derivatives_given_output, Tuple{Float64, typeof(sqrt), Bomb})
Tuple{Tuple{Float64}}
