Taking subspaces seriously #199

Closed
@mzgubic

Consider

```julia
julia> using ChainRulesCore, ChainRulesTestUtils, LinearAlgebra

julia> N = 3;

julia> D = Diagonal(randn(N));

julia> T = Tangent{Diagonal}(diag=rand(N));

julia> test_rrule(Diagonal, randn(N); output_tangent=D)
Test Summary:                           | Pass  Total
test_rrule: Diagonal on Vector{Float64} |    7      7
Test.DefaultTestSet("test_rrule: Diagonal on Vector{Float64}", Any[], 7, false, false)

julia> test_rrule(Diagonal, randn(N); output_tangent=T)
test_rrule: Diagonal on Vector{Float64}: Error During Test at /Users/mzgubic/JuliaEnvs/ChainRules.jl/dev/ChainRulesTestUtils/src/testers.jl:191
  Got exception outside of a @test
  DimensionMismatch("second dimension of A, 9, does not match length of x, 3")
```

What is happening here? The error comes from computing the reference cotangents with FiniteDifferences. It computes the Jacobian of the function, with inputs and outputs flattened to vectors by to_vec, and then takes the vector-Jacobian product with the flattened output tangent.
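A minimal sketch of this mechanism, written without FiniteDifferences.jl: fd_jacobian and vjp below are hypothetical helpers (a plain central-difference Jacobian and a transpose-Jacobian product), not the library's API. Flattening the Diagonal output densely, as to_vec does, gives a 9×3 Jacobian, so only a length-9 output tangent can be pulled back through it.

```julia
using LinearAlgebra

# Hypothetical helper (not the FiniteDifferences.jl API): central-difference
# Jacobian of a function mapping a vector to a vector.
function fd_jacobian(f, x::Vector{Float64}; h=1e-6)
    y = f(x)
    J = zeros(length(y), length(x))
    for j in eachindex(x)
        e = zeros(length(x))
        e[j] = h
        J[:, j] = (f(x .+ e) .- f(x .- e)) ./ (2h)
    end
    return J
end

# The primal Diagonal, with its output flattened densely as a 3-by-3 matrix.
diagonal_flat(x) = vec(Matrix(Diagonal(x)))

x = randn(3)
J = fd_jacobian(diagonal_flat, x)   # 9×3: one row per entry of the dense output

# Vector-Jacobian product: the cotangent of x is transpose(J) * ybar.
vjp(J, ybar) = transpose(J) * ybar

ybar_dense = vec(Matrix(Diagonal(rand(3))))  # length 9: compatible with J
ybar_fields = rand(3)                        # length 3: like a field-wise Tangent
xbar = vjp(J, ybar_dense)                    # works: a length-3 cotangent
# vjp(J, ybar_fields) throws a DimensionMismatch, as in the failing test.
```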

Since to_vec flattens the Diagonal output as a dense 3-by-3 matrix, the columns of the Jacobian have length 9, so the vector representing the output tangent must also have length 9. Now compare

```julia
julia> v, b = to_vec(D); v
9-element Vector{Float64}:
 -0.10041202431994393
  0.0
  0.0
  0.0
  1.344560222842734
  0.0
  0.0
  0.0
  0.47457750566464907

julia> v, b = to_vec(T); v
3-element Vector{Float64}:
 0.7929372912300103
 0.9731988127253042
 0.39340952399368345
```

We have chosen to densify the Diagonal matrix in to_vec (partly for this exact reason, see e.g. JuliaDiff/FiniteDifferences.jl#186), so we get a vector of length 9. The Tangent, on the other hand, is flattened field by field, and its only field, diag, has length 3.
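The two flattening conventions can be put side by side in a small sketch. flatten_dense, flatten_fields and embed_fields are hypothetical stand-ins, not the to_vec methods in FiniteDifferences.jl, and the sketch assumes the tangent's only field is diag, as for Tangent{Diagonal}. Rebuilding a Diagonal from the field vector and densifying it is one way to bring a length-3 tangent into the 9-dimensional coordinates of the output.

```julia
using LinearAlgebra

# Hypothetical stand-ins for the two flattening conventions; these are not
# the actual to_vec methods from FiniteDifferences.jl.
flatten_dense(D::Diagonal) = vec(Matrix(D))                       # n^2 entries
flatten_fields(diag_field::AbstractVector) = collect(diag_field)  # n entries

# Map a field-wise diagonal tangent into the dense coordinates by rebuilding
# the Diagonal and densifying it.
embed_fields(diag_field::AbstractVector) = flatten_dense(Diagonal(diag_field))

d = rand(3)
v_dense = flatten_dense(Diagonal(d))   # length 9, zeros off the diagonal
v_fields = flatten_fields(d)           # length 3
```

This embedding only works because the flattening of the primal output is known to be dense; it does not by itself settle which convention to_vec should use.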

Note that this is not just a matter of choosing whether to densify structured types and then being consistent about it. If we, e.g., choose not to densify the Diagonal, we run into a similar problem with test_rrule(*, rand(2, 2), rand(2, 2); output_tangent=Diagonal(rand(2))): the output is a Matrix, which flattens to a vector of length 4, while the tangent is a Diagonal, which flattens to a vector of length 2 if it is not densified.
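The second case can be checked with an exact Jacobian instead of finite differences. This sketch assumes the identities vec(A * B) = kron(transpose(B), I) * vec(A) = kron(I, A) * vec(B), and mul_jacobian is a hypothetical helper. The Jacobian with respect to the flattened inputs is 4×8, so a Diagonal output tangent can be pulled back only once it is densified to length 4.

```julia
using LinearAlgebra

# Exact Jacobian of (A, B) -> vec(A * B) with respect to [vec(A); vec(B)],
# from vec(A * B) = kron(transpose(B), I_m) * vec(A) = kron(I_n, A) * vec(B).
function mul_jacobian(A::Matrix{Float64}, B::Matrix{Float64})
    Id_m = Matrix{Float64}(I, size(A, 1), size(A, 1))
    Id_n = Matrix{Float64}(I, size(B, 2), size(B, 2))
    return hcat(kron(permutedims(B), Id_m), kron(Id_n, A))
end

A, B = rand(2, 2), rand(2, 2)
Jmul = mul_jacobian(A, B)              # 4×8: the output is a dense 2-by-2 Matrix

# Pulling back a dense cotangent gives [vec(Abar); vec(Bbar)].
Ybar = rand(2, 2)
grads = transpose(Jmul) * vec(Ybar)

# A Diagonal output tangent flattened two ways.
diag_tangent = Diagonal(rand(2))
dense_ybar = vec(Matrix(diag_tangent))    # length 4: compatible with Jmul
field_ybar = collect(diag(diag_tangent))  # length 2: not densified
# transpose(Jmul) * field_ybar throws a DimensionMismatch.
```

The pulled-back blocks agree with the familiar matrix-multiplication cotangents Ybar * transpose(B) and transpose(A) * Ybar, which is a quick sanity check on the Jacobian.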

A partial solution could be to use the ProjectTo mechanism from ChainRulesCore. In the first example, we would project the tangent onto the primal output (which is of type Diagonal), and the problem is solved. In the second case, however, ProjectTo(rand(2, 2))(Diagonal(rand(2))) leaves the Diagonal matrix intact, since a Diagonal matrix already lies in the subspace defined by the Matrix type.
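The limitation can be illustrated with hypothetical projectors that mimic the behaviour described above. project_onto_diagonal, project_onto_dense and flatten_structured are assumptions for illustration, not ChainRulesCore's ProjectTo. Projecting onto a dense primal returns the Diagonal unchanged, so under a non-densifying flattening its length still disagrees with that of the output.

```julia
using LinearAlgebra

# Hypothetical projectors illustrating the behaviour described above; they are
# NOT ChainRulesCore's ProjectTo.

# Onto a Diagonal primal: a field-wise tangent (its diag field) becomes a Diagonal.
project_onto_diagonal(diag_field::AbstractVector) = Diagonal(diag_field)

# Onto a dense Matrix primal: any AbstractMatrix of the right size already lies
# in that space, so a Diagonal tangent is returned unchanged, not densified.
function project_onto_dense(dx::AbstractMatrix, primal::Matrix)
    size(dx) == size(primal) || throw(DimensionMismatch("tangent and primal sizes differ"))
    return dx
end

# Hypothetical non-densifying flattening, as assumed in the second example.
flatten_structured(D::Diagonal) = collect(diag(D))
flatten_structured(M::AbstractMatrix) = vec(Matrix(M))

# First example: after projection the tangent has the same type as the primal
# output, so whichever flattening is used for Diagonal applies to both.
dx1 = project_onto_diagonal(rand(3))

# Second example: the Diagonal tangent survives projection onto the dense output,
# so its flattened length (2) still differs from that of the output (4).
primal = rand(2, 2) * rand(2, 2)
dx2 = project_onto_dense(Diagonal(rand(2)), primal)
```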
