Add rule for prod
#335
@@ -20,4 +20,60 @@
            end
        end
    end # sum abs2

    @testset "prod" begin
        @testset "Array{$T}" for T in [Float64, ComplexF64]
            @testset "size = $sz, dims = $dims" for (sz, dims) in [
                ((12,), :), ((12,), 1),
                ((3,4), 1), ((3,4), 2), ((3,4), :), ((3,4), [1,2]),
                ((3,4,1), 1), ((3,2,2), 3), ((3,2,2), 2:3),
                ]
                x = randn(T, sz)
                test_rrule(prod, x; fkwargs=(dims=dims,), check_inferred=true)
                x[1] = 0
                test_rrule(prod, x; fkwargs=(dims=dims,), check_inferred=true)
                x[5] = 0
                test_rrule(prod, x; fkwargs=(dims=dims,), check_inferred=true)
                x[3] = x[7] = 0  # two zeros along some slice, for any dims
                test_rrule(prod, x; fkwargs=(dims=dims,), check_inferred=true)

                if ndims(x) == 3
                    xp = PermutedDimsArray(x, (3,2,1))  # not a StridedArray
                    xpdot, xpbar = permutedims(rand(T, sz), (3,2,1)), permutedims(rand(T, sz), (3,2,1))
                    test_rrule(prod, xp ⊢ xpbar; fkwargs=(dims=dims,), check_inferred=true)
                end
            end

            @testset "structured wrappers" begin
                # Adjoint -- like PermutedDimsArray this may actually be used
                xa = adjoint(rand(T,4,4))
                test_rrule(prod, xa ⊢ rand(T,4,4))
                test_rrule(prod, xa ⊢ rand(T,4,4), fkwargs=(dims=2,))
                @test unthunk(rrule(prod, adjoint(rand(T,3,3)))[2](1.0)[2]) isa Matrix
                @test unthunk(rrule(prod, adjoint(rand(T,3,3)), dims=1)[2](ones(1,3))[2]) isa Matrix

                # Diagonal -- a stupid thing to do, product of zeros! Shouldn't be an error though:
                @test iszero(unthunk(rrule(prod, Diagonal(rand(T,3)))[2](1.0)[2]))
                @test iszero(unthunk(rrule(prod, Diagonal(rand(T,3)), dims=1)[2](ones(1,3))[2]))
                @test unthunk(rrule(prod, Diagonal(rand(T,1)))[2](1.0)[2]) == hcat(1)  # 1x1 sparse matrix
                @test unthunk(rrule(prod, Diagonal(ones(T,2)), dims=1)[2](ones(1,2))[2]) == [0 1; 1 0]

                # Triangular -- almost equally stupid
                @test iszero(unthunk(rrule(prod, UpperTriangular(rand(T,3,3)))[2](1.0)[2]))
                @test unthunk(rrule(prod, UpperTriangular(ones(T,2,2)))[2](1.0)[2]) == [0 0; 1 0]

                # Symmetric -- at least this doesn't have zeros, still an unlikely combination
                xs = Symmetric(rand(T,4,4))
                @test_skip test_rrule(prod, xs ⊢ rand(T,4,4))
                @test_skip test_rrule(prod, xs ⊢ rand(T,4,4), fkwargs=(dims=2,))
Comment on lines +67 to +68

Is there a bug in how FiniteDifferences does this, or am I thinking incorrectly about what it should produce?

With

I suggest checking what

For now I've left here a simpler test that it does run without error on a Symmetric:

Would it be better to have a test with

OK, I've switched it. It's mostly to check this doesn't give an error, the computation here does not care at all what kind of matrix it gets. Although

                @test unthunk(rrule(prod, Symmetric(T[1 2; -333 4]))[2](1.0)[2]) == [16 8; 8 4]
            end
        end
        @testset "Array{Float32}, no zero entries" begin
            v = [1f-5, 1f-10, 1f-15, 1f-20]
            @test prod(v) == 0
            @test unthunk(rrule(prod, v)[2](1f0)[2]) == zeros(4)
mcabbott marked this conversation as resolved.

            test_rrule(prod, v)
        end
    end # prod
end
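
For reference, the zero-handling cases exercised by the tests in this diff follow from the leave-one-out form of the derivative of a product. The following is a sketch of the underlying math for real inputs only, not a description of how the rule in this PR is implemented; the symbols $y$, $\bar{y}$ and $\bar{x}_i$ are notation assumed here (the output, its cotangent, and the resulting input cotangent).

$$
y = \prod_{j} x_j, \qquad
\bar{x}_i = \bar{y}\,\frac{\partial y}{\partial x_i} = \bar{y} \prod_{j \neq i} x_j
$$

$$
\prod_{j \neq i} x_j =
\begin{cases}
y / x_i & \text{if no } x_j \text{ is zero,} \\
\prod_{j \neq k} x_j \text{ for } i = k, \quad 0 \text{ for } i \neq k & \text{if exactly one entry } x_k = 0, \\
0 \text{ for every } i & \text{if two or more entries are zero.}
\end{cases}
$$

With a `dims` keyword the same case split applies independently to each slice being reduced, which is presumably why the tests set one zero, then a second, then "two zeros along some slice".
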
Apologies for diving in with my usual request, but is there a sensible way that we could restrict the type here, since the tests currently only look at `Array`s? E.g. I imagine that at least one of `Fill`, `Diagonal`, `StaticArray`, etc. will do something weird here. Would `StridedArray` suffice for your use case?

There is a test with PermutedDimsArray, which isn't a StridedArray, and I think it ought to work fine with StaticArrays, although I have not tested that in any depth. Diagonal seems to work although I struggle to imagine why calling `prod` on one would be a good idea, but weird things happen:

Fill makes a Vector gradient. Somehow `rrule(sum, Fill(2,3))` makes a Fill, because it simply broadcasts rather than calling `similar`. Is this something the package aims to guarantee? I don't see a test for it. Elsewhere it chooses `similar` over broadcasting to avoid other issues.

Apologies, it's late here. They are quite clearly in the tests...

We would definitely want the output of `rrule` w.r.t. a `Fill` argument to be either another `Fill` or an appropriate `Composite`. This is the kind of thing that e.g. `Zygote` can probably get right without a rule in `ChainRules`, so I think the ideal solution here is just not to implement a rule that covers `Fill`.

Also, should the result with `Diagonal` have zeros on the diagonal?

Something is broken on the last commit, but not this, this one you can do in your head: It's a product of mostly zeros, so the gradient with respect to nonzero entries still vanishes.
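
As a worked illustration of the "do it in your head" argument (a sketch for real entries, with $\bar{y} = 1$; the matrix names $D$ and $U$ are assumed here): `prod` over an $n \times n$ `Diagonal` multiplies all $n^2$ entries, of which $n^2 - n$ are structural zeros. For $n \ge 2$ that is at least two zeros, so every leave-one-out product still contains a zero factor:

$$
\frac{\partial}{\partial D_{ij}} \prod_{k,l} D_{kl} = \prod_{(k,l) \neq (i,j)} D_{kl} = 0 \quad \text{for all } i, j \text{ when } n \ge 2 .
$$

A matrix with exactly one zero behaves differently, because leaving out that single zero gives a nonzero product:

$$
U = \begin{pmatrix} 1 & 1 \\ 0 & 1 \end{pmatrix}, \qquad
\bar{U}_{ij} = \prod_{(k,l) \neq (i,j)} U_{kl}
\;\Longrightarrow\;
\bar{U} = \begin{pmatrix} 0 & 0 \\ 1 & 0 \end{pmatrix} .
$$

Both results agree with the `Diagonal(rand(T,3))` and `UpperTriangular(ones(T,2,2))` tests in the diff.
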

This could be arranged at the cost of more complexity... although possibly Fill ought to define `similar` more like that of Diagonal if it wishes to be preserved under such things?

Although clearly not all gradients are going to preserve this structure:

And here's how much Zygote can figure out right now:

Yes, good point.

This should produce a `Composite` with an appropriate `value` field.

This should also produce either a `Diagonal` or an appropriate `Composite`.

But with all of these types, I'm not saying that your PR needs to cover them. I'm purely suggesting it addresses the minimal set of types that you're confident are done correctly, and assumes that AD can do a reasonable job of deriving the others.

The answer to this is the result of a bug in `Zygote` that I should fix -- it looks like an example of what I'm commenting on here, where `getindex` has been implemented for too broad a set of types. `Zygote` really should be able to derive the rule for this properly, i.e. `getindex` only ever returns the `value` field of a `Fill`, so you shouldn't even need a rule for `getindex` for `Fill`.

So we can get this merged, shall we change this to `StridedArray` and then we can make a follow up later?