Unnecessary typeasserts in Flux.Optimise.apply! cause training to fail #816

Closed
@kernelmethod

Hi, all --

Here's an MWE where I try to learn a linear transform using a custom AbstractMatrix subtype:

using Flux
import Zygote: @adjoint, forward, gradient

struct MyMatrix{T, A <: AbstractMatrix{T}} <: AbstractMatrix{T}
    W :: A
    
    MyMatrix(W::AbstractMatrix) = begin
        # This example is a little contrived -- MyMatrix just wraps around copy(W')'.
        # If we skip this line then the rest of this code works fine.
        W = copy(W')'
        new{eltype(W),typeof(W)}(W)
    end
end

Flux.params(A::MyMatrix) = Flux.Params([A])
Tracker.istracked(::MyMatrix) = true   # Make Tracker think MyMatrix is a TrackedArray

Base.size(A::MyMatrix, args...) = size(A.W, args...)
Base.getindex(A::MyMatrix, args...) = getindex(A.W, args...)
Base.setindex!(A::MyMatrix, args...) = setindex!(A.W, args...)
Base.:(*)(A::MyMatrix, x::AbstractVector) = A.W * x
Base.:(*)(A::MyMatrix, x::AbstractMatrix) = A.W * x
@adjoint Base.:(*)(A::MyMatrix, x::AbstractVector) = forward(Base.:(*), A.W, x)
@adjoint Base.:(*)(A::MyMatrix, x::AbstractMatrix) = forward(Base.:(*), A.W, x)

model = Dense(randn(8,8) |> MyMatrix, zeros(8) |> param, identity)

loss(x,y) = Flux.mse(model(x),y)
∇ = gradient(() -> loss(randn(8), randn(8)), params(model))

### THIS WORKS:
Flux.Optimise.update!(ADAM(), params(model), ∇)

### THIS FAILS:
Flux.Optimise.update!(RMSProp(), params(model), ∇)

The last line fails because of a typeassert in apply!(o::RMSProp, x, Δ). By contrast, apply!(o::ADAM, x, Δ) contains no typeasserts, so it works fine.

There are similar typeasserts in the definitions of apply! for Momentum, Nesterov, and ADAGrad, but not in any of the other optimizers. Would it be possible to get rid of these? Or are they actually necessary?
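
To illustrate how a typeassert of this kind can reject a custom array type, here is a minimal sketch in plain Julia, with no Flux involved. It assumes the assertion compares the type of the cached optimizer state with the type of the parameter; the exact assertion in Flux is not quoted here, so this may not match the real failure mode of the MWE. `WrappedMatrix`, `strict_state`, and `relaxed_state` are hypothetical names used only for this illustration.

```julia
# Hypothetical stand-in for MyMatrix: a thin wrapper around another matrix.
struct WrappedMatrix{T, A <: AbstractMatrix{T}} <: AbstractMatrix{T}
    W :: A
    WrappedMatrix(W::AbstractMatrix) = new{eltype(W), typeof(W)}(W)
end

Base.size(A::WrappedMatrix) = size(A.W)
Base.getindex(A::WrappedMatrix, i::Int, j::Int) = A.W[i, j]
Base.setindex!(A::WrappedMatrix, v, i::Int, j::Int) = (A.W[i, j] = v)

x = WrappedMatrix(randn(2, 2))

# zero(x) goes through the generic similar(x) fallback, which allocates a
# plain Array because WrappedMatrix does not define its own `similar`.
z = zero(x)

# Hypothetical per-parameter state cache, keyed by object identity.
state = IdDict()

# Assumed pattern: look up (or create) the state, then assert its type.
strict_state(x)  = get!(state, x, zero(x))::typeof(x)

# The same lookup without the assertion.
relaxed_state(x) = get!(state, x, zero(x))

# The strict version throws a TypeError, since a plain Array is not a WrappedMatrix.
strict_failed = try
    strict_state(x)
    false
catch err
    err isa TypeError
end

println("typeof(x) = ", typeof(x))
println("typeof(z) = ", typeof(z))
println("strict lookup failed with TypeError: ", strict_failed)
println("relaxed lookup returned: ", typeof(relaxed_state(x)))
```

In this sketch the assertion adds nothing to correctness: the broadcasted updates that follow only need the state to have the same shape and element type as the parameter, not the same wrapper type. Whether that also holds for the assertions in Flux is the question raised above.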

Thanks!
