Consider

```julia
julia> N = 3;

julia> D = Diagonal(randn(N));

julia> T = Tangent{Diagonal}(diag=rand(N));

julia> test_rrule(Diagonal, randn(N); output_tangent=D)
Test Summary:                           | Pass  Total
test_rrule: Diagonal on Vector{Float64} |    7      7
Test.DefaultTestSet("test_rrule: Diagonal on Vector{Float64}", Any[], 7, false, false)

julia> test_rrule(Diagonal, randn(N); output_tangent=T)
test_rrule: Diagonal on Vector{Float64}: Error During Test at /Users/mzgubic/JuliaEnvs/ChainRules.jl/dev/ChainRulesTestUtils/src/testers.jl:191
  Got exception outside of a @test
  DimensionMismatch("second dimension of A, 9, does not match length of x, 3")
```
What is happening here? The error comes from computing the cotangents using FiniteDifferences. FiniteDifferences works by computing the Jacobian and then taking the vector-Jacobian product with the vector representing the output tangent.
Since the output is flattened as a dense 3-by-3 matrix, each column of the Jacobian has length 9, so it has to be multiplied by a vector of length 9. Now compare
```julia
julia> v, b = to_vec(D); v
9-element Vector{Float64}:
 -0.10041202431994393
  0.0
  0.0
  0.0
  1.344560222842734
  0.0
  0.0
  0.0
  0.47457750566464907

julia> v, b = to_vec(T); v
3-element Vector{Float64}:
 0.7929372912300103
 0.9731988127253042
 0.39340952399368345
```
We have chosen to densify the `Diagonal` matrix in `to_vec` (partly for this exact reason, see e.g. JuliaDiff/FiniteDifferences.jl#186), so we get a vector of length 9. The `Tangent`, on the other hand, is converted to a vector by converting each of its fields, which gives a vector of length only 3.
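The shape mismatch can be reproduced without ChainRulesTestUtils. The sketch below is a hedged illustration that uses only the `LinearAlgebra` standard library: `fd_jacobian` and `densify` are hypothetical helpers standing in for the Jacobian computed by FiniteDifferences and for the densifying `to_vec`, and a plain vector stands in for the field-wise flattening of `Tangent{Diagonal}`. It assumes, as described above, that the output is flattened densely while the tangent is flattened field by field.

```julia
using LinearAlgebra

# Hypothetical helper (not the FiniteDifferences API): central-difference
# Jacobian of a vector-to-vector function f at x.
function fd_jacobian(f, x::AbstractVector; h=1e-6)
    m = length(f(x))
    J = zeros(m, length(x))
    for j in eachindex(x)
        e = zeros(length(x))
        e[j] = h
        J[:, j] = (f(x + e) - f(x - e)) / (2h)
    end
    return J
end

N = 3
x = randn(N)

# Hypothetical stand-in for the densifying to_vec: N^2 = 9 entries.
densify(D::Diagonal) = vec(Matrix(D))
J = fd_jacobian(x -> densify(Diagonal(x)), x)   # 9×3: each column has length 9

# Field-wise flattening of the tangent keeps only the N = 3 diagonal values,
# which cannot be multiplied with J' (a 3×9 matrix).
t_fields = rand(N)
mismatch = try
    J' * t_fields
    false
catch err
    err isa DimensionMismatch
end

# Densifying the same diagonal values gives 9 entries, and the
# vector-Jacobian product goes through, returning a 3-element cotangent.
xbar = J' * densify(Diagonal(t_fields))
```

Here `mismatch` is `true`, mirroring the `DimensionMismatch` from `test_rrule`, while `xbar` recovers the diagonal values because the map from `x` to the densified output is linear.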
Note that this is not just a matter of choosing whether to densify structural types and then being consistent. If we, e.g., choose not to densify the `Diagonal`, we run into a similar problem with `test_rrule(*, rand(2, 2), rand(2, 2); output_tangent=Diagonal(rand(2)))`: the output is now a `Matrix`, which flattens to a vector of length 4, while the tangent is a `Diagonal`, which flattens to a vector of length 2 if we don't densify it.
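The same kind of mismatch, with the roles reversed, can be sketched for the `*` example. `flatten` is a hypothetical non-densifying flattening assumed for illustration; it is not the FiniteDifferences `to_vec`.

```julia
using LinearAlgebra

A, B = rand(2, 2), rand(2, 2)
Y = A * B                   # primal output: a dense 2×2 Matrix
ybar = Diagonal(rand(2))    # requested output tangent

# Hypothetical flattening that does NOT densify structured matrices.
flatten(M::Matrix) = vec(M)
flatten(D::Diagonal) = diag(D)

length(flatten(Y))       # 4: the Jacobian's columns have this length
length(flatten(ybar))    # 2: too short for the vector-Jacobian product
```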
A partial solution could be to use the `ProjectTo` mechanism from ChainRulesCore. In the first example, we would project the tangent onto the primal output (which is a `Diagonal`), and the problem is solved. In the second case, however, `ProjectTo(rand(2, 2))(Diagonal(rand(2)))` would leave the `Diagonal` matrix intact, since a `Diagonal` matrix already lies in the subspace defined by the `Matrix` type.
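The subspace argument can be sketched with two hypothetical projection functions. They are assumptions made for illustration and only mimic the idea of `ProjectTo`, not ChainRulesCore's implementation; a `NamedTuple` again stands in for `Tangent{Diagonal}`.

```julia
using LinearAlgebra

# Hypothetical projections sketching the idea behind ProjectTo (NOT its implementation).

# Onto the Diagonal subspace: a structural tangent (NamedTuple stand-in for
# Tangent{Diagonal}) or any matrix becomes a Diagonal.
project_diagonal(t::NamedTuple) = Diagonal(t.diag)
project_diagonal(dx::AbstractMatrix) = Diagonal(diag(dx))

# Onto the space of all 2×2 matrices: every 2×2 array already lies there,
# so the cotangent is returned unchanged, type included.
project_matrix(dx::AbstractMatrix) = dx

# Case 1: the projected tangent has the primal's type, so it densifies like the output.
p1 = project_diagonal((diag = rand(3),))
length(vec(Matrix(p1)))    # 9, same as the densified Diagonal output

# Case 2: the Diagonal survives projection onto Matrix, so a non-densifying
# flattening would still see 2 entries instead of 4.
p2 = project_matrix(Diagonal(rand(2)))
p2 isa Diagonal            # true
```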