
Error in gradient of custom AbstractMatrix subtypes #1146

@simonmandlik

I am struggling with defining a custom rrule and computing gradients of a custom AbstractMatrix subtype.

For simplicity, assume a naive matmul wrapped in a struct:

using Flux, Zygote, ChainRulesCore

struct MyMatrix{T, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    W::U
end

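# naive triple-loop matrix product; the result eltype follows both inputs
function Base.:*(A::MyMatrix, B::AbstractMatrix)
    C = zeros(promote_type(eltype(A.W), eltype(B)), size(A.W, 1), size(B, 2))
    for i = 1:size(A.W, 1), j = 1:size(B, 2), k = 1:size(A.W, 2)
        C[i,j] += A.W[i,k] * B[k,j]
    end
    return C
end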

# placeholder pullback: reports no gradient for any argument
function ChainRulesCore.rrule(::typeof(Base.:*), A::MyMatrix, B::AbstractMatrix)
    A.W * B, Δ -> (NoTangent(), NoTangent(), NoTangent())
end

W, B = rand(2,2), rand(2,2)
@assert MyMatrix(W) * B ≈ W * B
gradient(A -> sum(A*B), MyMatrix(W))
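
As written, `MyMatrix` subtypes `AbstractMatrix` but defines neither `size` nor `getindex`, the two methods Julia's `AbstractArray` interface requires. Generic fallbacks such as printing, `sum` or `collect` call them, so an incomplete wrapper can fail in places unrelated to `*`. A minimal sketch that forwards both methods to the wrapped matrix (this is an assumption about how the wrapper is meant to behave, not something stated in the report):

```julia
# Sketch: the two methods Julia's AbstractArray interface requires,
# both forwarded to the wrapped matrix W.
struct MyMatrix{T, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    W::U
end

Base.size(A::MyMatrix) = size(A.W)
Base.getindex(A::MyMatrix, i::Int, j::Int) = A.W[i, j]

A = MyMatrix([1.0 2.0; 3.0 4.0])
size(A)   # (2, 2)
sum(A)    # 10.0
```

With these two methods in place, the default `IndexCartesian` style lets generic `AbstractArray` code iterate over the wrapper element by element.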

Computing the gradient fails. For example, differentiating with respect to B gives:

julia> gradient(B -> sum(MyMatrix(W)*B), B)
ERROR: type Nothing has no field method
Stacktrace:
  [1] getproperty
    @ ./Base.jl:42 [inlined]
  [2] matching_cr_sig(t::IRTools.Inner.Meta, s::Nothing)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/chainrules.jl:72
  [3] has_chain_rrule(T::Type)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/chainrules.jl:54
  [4] #s3060#1217
    @ ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:20 [inlined]
  [5] var"#s3060#1217"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
  [6] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
  [7] _pullback
    @ ./REPL[8]:1 [inlined]
  [8] _pullback(ctx::Zygote.Context, f::var"#5#6", args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
  [9] _pullback(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:34
 [10] pullback(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:40
 [11] gradient(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:75
 [12] top-level scope
    @ REPL[8]:1
 [13] top-level scope
    @ ~/.julia/packages/CUDA/sCev8/src/initialization.jl:52
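
One possible reading of the trace (an assumption, not confirmed in this thread) is that `matching_cr_sig` receives `nothing` because Julia cannot select a single `rrule` method for this call. That could happen if the new method is ambiguous with a generic `rrule` for `*` on matrices that ChainRules (loaded by Zygote) defines. `which(f, types)` returns the method that would be called for the given argument types and throws if there is no unique match, so it can serve as a quick check:

```julia
using Zygote, ChainRulesCore  # loading Zygote also loads ChainRules and its rules

struct MyMatrix{T, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    W::U
end

# Same signature as in the report above.
function ChainRulesCore.rrule(::typeof(*), A::MyMatrix, B::AbstractMatrix)
    A.W * B, Δ -> (NoTangent(), NoTangent(), NoTangent())
end

W, B = rand(2, 2), rand(2, 2)
sig = Tuple{typeof(*), typeof(MyMatrix(W)), typeof(B)}

# `which` throws when no single most-specific method matches `sig`.
rrule_is_unique = try
    which(ChainRulesCore.rrule, sig)
    true
catch err
    @warn "No unique rrule method for this call" err
    false
end
```

If the check reports no unique method, the error message lists the competing candidates, which shows which signature the new rule collides with.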

What is the correct way to define rrule for Base.:* in situations like this one?
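
A sketch of an rrule that returns actual tangents for both arguments, under two stated assumptions. First, the `<:Real` element-type bounds are an assumption intended to make the method strictly more specific than a generic ChainRules rule for `*` on real or complex matrices, which may avoid an ambiguity; this is not verified here. Second, the tangent for the wrapper is expressed as a structural `Tangent` with a `W` field, using the matrix-product identities for C = W * B: the gradient with respect to W is Δ * Bᵀ and with respect to B is Wᵀ * Δ. The pullback name `mymatrix_mul_pullback` is illustrative:

```julia
using ChainRulesCore

struct MyMatrix{T, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    W::U
end
Base.size(A::MyMatrix) = size(A.W)
Base.getindex(A::MyMatrix, i::Int, j::Int) = A.W[i, j]

# Forward pass; the naive loop from the report would work here as well.
Base.:*(A::MyMatrix, B::AbstractMatrix) = A.W * B

# Assumption: the <:Real bounds make this strictly more specific than a
# generic rrule for * on real/complex matrices, so dispatch can pick it.
function ChainRulesCore.rrule(::typeof(*), A::MyMatrix{<:Real}, B::AbstractMatrix{<:Real})
    C = A * B
    function mymatrix_mul_pullback(ΔC)
        Δ = unthunk(ΔC)
        # For C = W * B: dL/dW = Δ * Bᵀ and dL/dB = Wᵀ * Δ.
        ΔA = Tangent{typeof(A)}(; W = Δ * B')  # structural tangent for the wrapper
        ΔB = A.W' * Δ
        return NoTangent(), ΔA, ΔB
    end
    return C, mymatrix_mul_pullback
end

W, B = rand(2, 2), rand(2, 2)
C, back = ChainRulesCore.rrule(*, MyMatrix(W), B)
_, ΔA, ΔB = back(ones(2, 2))
```

With Zygote loaded, `gradient(B -> sum(MyMatrix(W) * B), B)` should then dispatch to this rule, and a gradient taken with respect to the wrapper would come back as a NamedTuple with a `W` field; neither behavior is confirmed in this thread.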
