Closed
Description
I am struggling with defining a custom rrule and computing gradients for a custom AbstractMatrix subtype.
For simplicity, assume a naive matmul wrapped in a struct:
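```julia
using Flux, Zygote, ChainRulesCore

# Thin wrapper around an ordinary matrix
struct MyMatrix{T, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    W::U
end

# Naive triple-loop matrix multiplication (output eltype promoted from the inputs)
function Base.:*(A::MyMatrix, B::AbstractMatrix)
    C = zeros(promote_type(eltype(A.W), eltype(B)), size(A.W, 1), size(B, 2))
    for i = 1:size(A.W, 1), j = 1:size(B, 2), k = 1:size(A.W, 2)
        C[i, j] += A.W[i, k] * B[k, j]
    end
    return C
end

# Custom rule; the pullback is a placeholder that returns no tangents
function ChainRulesCore.rrule(::typeof(Base.:*), A::MyMatrix, B::AbstractMatrix)
    return A.W * B, Δ -> (NoTangent(), NoTangent(), NoTangent())
end

W, B = rand(2, 2), rand(2, 2)
@assert MyMatrix(W) * B ≈ W * B
gradient(A -> sum(A * B), MyMatrix(W))
```

Computing the gradient fails, e.g. with respect to `B`: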
```
julia> gradient(B -> sum(MyMatrix(W)*B), B)
ERROR: type Nothing has no field method
Stacktrace:
  [1] getproperty
    @ ./Base.jl:42 [inlined]
  [2] matching_cr_sig(t::IRTools.Inner.Meta, s::Nothing)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/chainrules.jl:72
  [3] has_chain_rrule(T::Type)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/chainrules.jl:54
  [4] #s3060#1217
    @ ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:20 [inlined]
  [5] var"#s3060#1217"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
  [6] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
  [7] _pullback
    @ ./REPL[8]:1 [inlined]
  [8] _pullback(ctx::Zygote.Context, f::var"#5#6", args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
  [9] _pullback(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:34
 [10] pullback(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:40
 [11] gradient(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:75
 [12] top-level scope
    @ REPL[8]:1
 [13] top-level scope
    @ ~/.julia/packages/CUDA/sCev8/src/initialization.jl:52
```
What is the correct way to define an rrule for `Base.:*` in situations like this one?
racinmat
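A minimal sketch of an rrule with a real pullback, under stated assumptions. It assumes (not verified here) that the `Nothing` error comes from a method ambiguity between the rule above and ChainRules' generic `rrule` for `*`, which is restricted to matrices with real or complex element types; the sketch therefore restricts the element types so its signature is strictly more specific. The `size`/`getindex` methods, the `RealOrComplex` alias and the `mymatrix_mul_pullback` name are hypothetical additions, not part of the original snippet. The pullback uses the standard matmul adjoints: for `C = W * B`, `∂W = ΔC * B'` and `∂B = W' * ΔC`, with the wrapper's tangent expressed as a structural `Tangent`.

```julia
using ChainRulesCore

# Same wrapper as in the question
struct MyMatrix{T, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    W::U
end

# Hypothetical: minimal AbstractArray interface so MyMatrix behaves like a matrix
Base.size(A::MyMatrix) = size(A.W)
Base.getindex(A::MyMatrix, i::Int, j::Int) = A.W[i, j]

# Naive triple-loop matmul, as in the question
function Base.:*(A::MyMatrix, B::AbstractMatrix)
    C = zeros(promote_type(eltype(A.W), eltype(B)), size(A.W, 1), size(B, 2))
    for i = 1:size(A.W, 1), j = 1:size(B, 2), k = 1:size(A.W, 2)
        C[i, j] += A.W[i, k] * B[k, j]
    end
    return C
end

# Hypothetical alias: element types accepted by the rule below
const RealOrComplex = Union{Real, Complex}

# Assumption: restricting eltypes avoids an ambiguity with ChainRules' generic rule for `*`
function ChainRulesCore.rrule(::typeof(*), A::MyMatrix{<:RealOrComplex},
                              B::AbstractMatrix{<:RealOrComplex})
    C = A * B
    function mymatrix_mul_pullback(ΔC)
        Δ = unthunk(ΔC)
        # Structural tangent for the wrapper: only its field W carries a gradient
        ∂A = Tangent{typeof(A)}(; W = Δ * B')
        # Lazily computed tangent for B
        ∂B = @thunk(A.W' * Δ)
        return NoTangent(), ∂A, ∂B
    end
    return C, mymatrix_mul_pullback
end
```

With Zygote loaded, `gradient(A -> sum(A * B), MyMatrix(W))` would then be expected to dispatch to this rule and return a structural gradient with a `W` field; this is an expectation, not something checked here.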