-
Notifications
You must be signed in to change notification settings - Fork 64
Add derivatives_given_input
#456
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -158,6 +158,7 @@ end | |
|
||
Compute the derivative of scalar function `f` at primal input point `xs...`, | ||
given that it had primal output `Ω`. | ||
|
||
Return a tuple of tuples with the partial derivatives of `f` with respect to the `xs...`. | ||
The derivative of the `i`-th component of `f` with respect to the `j`-th input can be | ||
accessed as `Df[i][j]`, where `Df = derivatives_given_output(Ω, f, xs...)`. | ||
|
@@ -173,16 +174,43 @@ accessed as `Df[i][j]`, where `Df = derivatives_given_output(Ω, f, xs...)`. | |
""" | ||
function derivatives_given_output end | ||
|
||
""" | ||
derivatives_given_output(f, xs...) | ||
|
||
Compute the derivative of scalar function `f` at primal input point `xs...`, | ||
when this is possible *without* knowing primal output `Ω`. | ||
|
||
!!! warning "Experimental" | ||
""" | ||
function derivatives_given_input end | ||
|
||
function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) | ||
return @strip_linenos quote | ||
function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) | ||
given_output = @strip_linenos quote | ||
@inline function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) | ||
$(__source__) | ||
$(setup_stmts...) | ||
return $(Expr(:tuple, partials...)) | ||
end | ||
end | ||
given_input = @strip_linenos quote | ||
@inline function ChainRulesCore.derivatives_given_input(::Core.Typeof($f), $(inputs...)) | ||
$(__source__) | ||
$(setup_stmts...) | ||
return $(Expr(:tuple, partials...)) | ||
end | ||
end | ||
Comment on lines
+195
to
+201
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In some cases it's it might also be useful to know if the input is no longer necessary given the outputs. Maybe a more general solution could be to have something like a trait Then things like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. Re not needing the input, my example here is FluxML/NNlib.jl#346 where I want Re how to tell, agree we could have some indicator beyond checking for methods. In JuliaDiff/ChainRules.jl#529 it seems convenient to do this based on For some functions you have a choice about whether to write it with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this:
are you thinking about something like For There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I worded that incorrectly. I meant to say that I was not completely sure whether it would make sense to distinguish whether a From a practical perspective, I wonder how easy it is by looking at the expression in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok. One advantage of just checking whether methods exist, instead of having a special indicator function, is that it's harder to lie. We will probably want some methods not defined through the macro, such as: One more issue with the business of the macro looking for Omega (or x) in the expressions is that sometimes there is a choice, e.g. If you do have an indicator There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact, even with just one function, if the way you check whether a method exists is type inference, then it will already tell you whether the result is used:
JuliaDiff/ChainRules.jl#529 now does this, instead of using this PR.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
That's a good point. So to summarize the decision points are the following.
For 1., maybe one could check whether methods Question 2. is less clear. The most manual solution is to add to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What I am leaning towards is just keeping the existing I guess the approach depends on being willing to call (Perhaps it should be a little more picky, are there cases where a rule actually returns the For the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, here's why julia> relu(x) = ifelse(x>0, x, zero(x));
julia> @scalar_rule relu(x::Real) isnothing(Ω) ? x > 0 : Ω > 0
julia> derivatives_given_output(nothing, relu, 1.0)
((true,),)
julia> derivatives_given_output(1.0, relu, nothing)
ERROR: MethodError: no method matching derivatives_given_output(::Float64, ::typeof(relu), ::Nothing)
Closest candidates are:
derivatives_given_output(::Any, ::typeof(relu), ::Real) at REPL[20]:1
julia> struct Bomb <: Real end
julia> derivatives_given_output(1.0, relu, Bomb())
((true,),)
julia> Core.Compiler._return_type(derivatives_given_output, Tuple{Bomb, typeof(sqrt), Float64})
Union{}
julia> Core.Compiler._return_type(derivatives_given_output, Tuple{Float64, typeof(sqrt), Bomb})
Tuple{Tuple{Float64}} |
||
return if _free_of_omega([inputs, partials]) | ||
:($given_output; $given_input) | ||
else | ||
given_output | ||
end | ||
end | ||
|
||
_free_of_omega(v::Union{Vector,Tuple}) = all(_free_of_omega, v) | ||
_free_of_omega(ex::Expr) = _free_of_omega(ex.args) | ||
_free_of_omega(s::Symbol) = s != :Ω | ||
_free_of_omega(other) = true # (@show other typeof(other); true) | ||
|
||
function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) | ||
n_outputs = length(partials) | ||
n_inputs = length(inputs) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inlining here is probably a good idea regardless of whether one wants
derivatives_given_input
. Maybe it can be added in a separate PR?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I forgot about that. Will fiddle a bit more to see if these matter.