Closed
Hi, all --
Here's an MWE where I try to learn a linear transform using a custom AbstractMatrix
subtype:
using Flux
import Zygote: @adjoint, forward, gradient
import Tracker

struct MyMatrix{T, A <: AbstractMatrix{T}} <: AbstractMatrix{T}
    W :: A
    MyMatrix(W::AbstractMatrix) = begin
        # This example is a little contrived -- MyMatrix just wraps around copy(W')'.
        # If we skip this line then the rest of this code works fine.
        W = copy(W')'
        new{eltype(W), typeof(W)}(W)
    end
end

Flux.params(A::MyMatrix) = Flux.Params([A])
Tracker.istracked(::MyMatrix) = true  # Make Tracker think MyMatrix is a TrackedArray

Base.size(A::MyMatrix, args...) = size(A.W, args...)
Base.getindex(A::MyMatrix, args...) = getindex(A.W, args...)
Base.setindex!(A::MyMatrix, args...) = setindex!(A.W, args...)
Base.:(*)(A::MyMatrix, x::AbstractVector) = A.W * x
Base.:(*)(A::MyMatrix, x::AbstractMatrix) = A.W * x
@adjoint Base.:(*)(A::MyMatrix, x::AbstractVector) = forward(Base.:(*), A.W, x)
@adjoint Base.:(*)(A::MyMatrix, x::AbstractMatrix) = forward(Base.:(*), A.W, x)

model = Dense(randn(8, 8) |> MyMatrix, zeros(8) |> param, identity)
loss(x, y) = Flux.mse(model(x), y)
∇ = gradient(() -> loss(randn(8), randn(8)), params(model))
### THIS WORKS:
Flux.Optimise.update!(ADAM(), params(model), ∇)
### THIS FAILS:
Flux.Optimise.update!(RMSProp(), params(model), ∇)
The error raised by the last line is caused by a typeassert in apply!(o::RMSProp, x, Δ); apply!(o::ADAM, x, Δ) doesn't contain any typeasserts and thus works fine. There are similar typeasserts in the definitions of apply! for Momentum, Nesterov, and ADAGrad, but not in any of the other optimizers. Would it be possible to get rid of these? Or are they actually necessary?
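For reference, a sketch of the apply! method in question (paraphrased from Flux's optimisers.jl; the exact code may differ between versions, and ϵ is assumed to be Flux's small stabilizing constant):

function apply!(o::RMSProp, x, Δ)
    η, ρ = o.eta, o.rho
    # The ::typeof(x) assert below is what throws: for a MyMatrix x, zero(x)
    # falls back to the generic AbstractArray method and returns a plain
    # Matrix, so the cached accumulator is never a MyMatrix.
    acc = get!(o.acc, x, zero(x))::typeof(x)
    @. acc = ρ * acc + (1 - ρ) * Δ^2
    @. Δ *= η / (√acc + ϵ)
end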
Thanks!