-
-
Notifications
You must be signed in to change notification settings - Fork 31
Fix AD for parameters #175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,37 +3,47 @@ using Integrals | |
if isdefined(Base, :get_extension) | ||
using Zygote | ||
import ChainRulesCore | ||
import ChainRulesCore: NoTangent | ||
import ChainRulesCore: NoTangent, ProjectTo | ||
else | ||
using ..Zygote | ||
import ..Zygote.ChainRulesCore | ||
import ..Zygote.ChainRulesCore: NoTangent | ||
import ..Zygote.ChainRulesCore: NoTangent, ProjectTo | ||
end | ||
ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...) | ||
ChainRulesCore.@non_differentiable Integrals.isinplace(f, n) # fixes #99 | ||
|
||
function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, lb, ub, | ||
p; | ||
kwargs...) | ||
out = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...) | ||
|
||
# the adjoint will be the integral of the input sensitivities, so it maps the | ||
# sensitivity of the output to an object of the type of the parameters | ||
function quadrature_adjoint(Δ) | ||
y = typeof(Δ) <: Array{<:Number, 0} ? Δ[1] : Δ | ||
# https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes | ||
y = cache.nout == 1 ? Δ[1] : Δ # interpret the output as scalar | ||
# this will not be type-stable, but I believe it is unavoidable due to two ambiguities: | ||
# 1. Δ is the output of the algorithm, and when nout = 1 it is undefined whether the | ||
# output of the algorithm must be a scalar or a vector of length 1 | ||
# 2. when nout = 1 the integrand can either be a scalar or a vector of length 1 | ||
if isinplace(cache) | ||
dx = zeros(cache.nout) | ||
_f = x -> cache.f(dx, x, p) | ||
if sensealg.vjp isa Integrals.ZygoteVJP | ||
dfdp = function (dx, x, p) | ||
_, back = Zygote.pullback(p) do p | ||
_dx = Zygote.Buffer(x, cache.nout, size(x, 2)) | ||
z, back = Zygote.pullback(p) do p | ||
_dx = cache.nout == 1 ? | ||
Zygote.Buffer(dx, eltype(y), size(x, ndims(x))) : | ||
Zygote.Buffer(dx, eltype(y), cache.nout, size(x, ndims(x))) | ||
cache.f(_dx, x, p) | ||
copy(_dx) | ||
end | ||
|
||
z = zeros(size(x, 2)) | ||
for idx in 1:size(x, 2) | ||
z[1] = 1 | ||
dx[:, idx] = back(z)[1] | ||
z[idx] = 0 | ||
z .= zero(eltype(z)) | ||
for idx in 1:size(x, ndims(x)) | ||
z isa Vector ? (z[idx] = y) : (z[:, idx] .= y) | ||
dx[:, idx] .= back(z)[1] | ||
z isa Vector ? (z[idx] = zero(eltype(z))) : | ||
(z[:, idx] .= zero(eltype(z))) | ||
end | ||
end | ||
elseif sensealg.vjp isa Integrals.ReverseDiffVJP | ||
|
@@ -44,14 +54,21 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal | |
if sensealg.vjp isa Integrals.ZygoteVJP | ||
if cache.batch > 0 | ||
dfdp = function (x, p) | ||
_, back = Zygote.pullback(p -> cache.f(x, p), p) | ||
z, back = Zygote.pullback(p -> cache.f(x, p), p) | ||
# messy, there are 4 cases, some better in forward mode than reverse | ||
# 1: length(y) == 1 and length(p) == 1 | ||
# 2: length(y) > 1 and length(p) == 1 | ||
# 3: length(y) == 1 and length(p) > 1 | ||
# 4: length(y) > 1 and length(p) > 1 | ||
Comment on lines
+58
to
+62
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure what this comment is here for? I mean, agreed these are the 4 cases and sometimes forward is better than reverse, but I don't understand why that's here 😅 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah sorry, I just left that as a note to myself and I can remove it. The lines of code below need all the if statements to handle these cases, so I found it helpful to list them. |
||
|
||
out = zeros(length(p), size(x, 2)) | ||
z = zeros(size(x, 2)) | ||
for idx in 1:size(x, 2) | ||
z[idx] = 1 | ||
out[:, idx] = back(z)[1] | ||
z[idx] = 0 | ||
z .= zero(eltype(z)) | ||
out = zeros(eltype(p), size(p)..., size(x, ndims(x))) | ||
for idx in 1:size(x, ndims(x)) | ||
z isa Vector ? (z[idx] = y) : (z[:, idx] .= y) | ||
out isa Vector ? (out[idx] = back(z)[1]) : | ||
(out[:, idx] .= back(z)[1]) | ||
z isa Vector ? (z[idx] = zero(y)) : | ||
(z[:, idx] .= zero(eltype(y))) | ||
end | ||
out | ||
end | ||
|
@@ -76,17 +93,30 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal | |
do_inf_transformation = Val(false), | ||
cache.kwargs...) | ||
|
||
if p isa Number | ||
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...)[1] | ||
else | ||
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...).u | ||
end | ||
project_p = ProjectTo(p) | ||
dp = project_p(Integrals.__solvebp_call(dp_cache, | ||
alg, | ||
sensealg, | ||
lb, | ||
ub, | ||
p; | ||
kwargs...).u) | ||
|
||
if lb isa Number | ||
dlb = -_f(lb) | ||
dub = _f(ub) | ||
dlb = cache.batch > 0 ? -_f([lb]) : -_f(lb) | ||
dub = cache.batch > 0 ? _f([ub]) : _f(ub) | ||
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dlb, dub, dp) | ||
else | ||
# we need to compute 2*length(lb) integrals on the faces of the hypercube, as we | ||
# can see from writing the multidimensional integral as an iterated integral | ||
# alternatively we can use Stokes' theorem to replace the integral on the | ||
# boundary with a volume integral of the flux of the integrand | ||
# ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the | ||
# dimensionality of the integral or the quadrature used (such as quadratures | ||
# that don't evaluate points on the boundaries) and it could be generalized to | ||
# other kinds of domains. The only question is to determine ω in terms of f and | ||
# the deformation of the surface (e.g. consider integral over an ellipse and | ||
# asking for the derivative of the result w.r.t. the semiaxes of the ellipse) | ||
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), | ||
NoTangent(), dp) | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Won't the `nout = 1` case be unable to use `similar`, because it could be a scalar?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since `p` is an array (see function signature), so are `rawp` (look below) and `dualp`, so I think `similar` will be defined.