
Optimise a subset of parameters #35

Closed
@mcabbott

Description


Flux's trainable works like this:

julia> Flux.trainable(BatchNorm(2, relu))  # this avoids half the parameters
(Float32[0.0, 0.0], Float32[1.0, 1.0])

julia> Functors.children(BatchNorm(2, relu))   # this sees them all, for |> gpu
(λ = NNlib.relu, β = Float32[0.0, 0.0], γ = Float32[1.0, 1.0], μ = Float32[0.0, 0.0], σ² = Float32[1.0, 1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, track_stats = true, active = nothing, chs = 2)

This doesn't seem great: it relies on objectid to know which parameters those really are. So something like this:

function _trainable_walk(f, x)
  func, re = functor(x)      # all children of x, plus a reconstructor
  nb = trainable(x)          # Flux-style tuple of the trainable values
  re(map(c -> c in nb ? f(c) : c, func))   # apply f only to children found in nb
end

will not work correctly when, say, β === SA[0.0, 0.0] === μ.
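To make that concrete, here is a minimal sketch of my own (using StaticArrays; not from the issue): with isbits static arrays, β and μ are indistinguishable both by objectid and by ==, so matching the children against the trainable tuple wrongly marks μ as trainable.

using StaticArrays

β = SA[0.0, 0.0]            # trainable shift
μ = SA[0.0, 0.0]            # running mean, not trainable, yet β === μ here
nb = (β, SA[1.0, 1.0])      # what a Flux-style trainable might return: (β, γ)

# Marking each child by membership in nb cannot tell β and μ apart:
map((β = β, γ = SA[1.0, 1.0], μ = μ)) do c
  c in nb ? :train : :freeze
end
# (β = :train, γ = :train, μ = :train)  -- μ is wrongly marked trainable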

How should it work?

  • One idea would be to clone the @functor macro, to have @trainable BatchNorm (β, γ)? In fact this case is even worse: the existing definition checks a field's value (affine) at run time, although we could probably move affine into the type.

  • Another idea would be just to have trainable(::BatchNorm) = (:β, :γ) return the symbols. That's much easier to write and perhaps less mysterious. It might be slower, but do we care? Or it might not be, if the symbols are known from the type. It would also be easy to allow Flux-style tuples of values as a fallback, by detecting NTuple{N,Symbol} etc., making it easier to have both old- and new-style definitions at once. A rough sketch of the symbol-based idea follows this list.
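Here is that rough sketch (my own, with illustrative names; it assumes Flux's BatchNorm and Functors.functor, and defines a fresh trainable rather than extending any package's method): trainable returns field names, and the walk reconstructs through functor, applying f only to the named children.

using Flux: BatchNorm, relu
using Functors: functor

trainable(x) = ()                      # fallback: nothing extra is trainable
trainable(::BatchNorm) = (:β, :γ)      # names, not values

function _trainable_walk(f, x)
  func, re = functor(x)                # NamedTuple of all children, plus reconstructor
  names = trainable(x)
  new = map(keys(func), values(func)) do k, c
    k in names ? f(c) : c              # match by field name, not by value
  end
  re(NamedTuple{keys(func)}(new))
end

_trainable_walk(c -> 10 .* c, BatchNorm(2, relu))  # scales only β and γ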

Whichever form it takes, trainable would be used during setup, in just one pass. After that, the tree of optimiser states should tell you whether or not to update a given array, so update need never call it.
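For illustration only (the velocity field name is made up, and this is not a claim about the package's actual representation), such a state tree for a BatchNorm might carry real state only at β and γ, with an inert marker everywhere else:

using Flux: BatchNorm, relu

model = BatchNorm(2, relu)
state = (λ = nothing, β = (velocity = zeros(Float32, 2),),
         γ = (velocity = zeros(Float32, 2),),
         μ = nothing, σ² = nothing, ϵ = nothing, momentum = nothing,
         affine = nothing, track_stats = nothing, active = nothing, chs = nothing)
# update can walk `model` and `state` in parallel and skip `nothing` entries,
# so it never needs to call trainable again after setup.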

What might call it more often is destructure, which I think should walk only the trainable parameters, and which will sometimes be called in a loop.
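As a usage sketch of the behaviour wanted here, assuming a trainable-aware destructure like the one Optimisers.jl now exports (it did not exist when this was written):

using Optimisers: destructure
using Flux: BatchNorm, relu

flat, re = destructure(BatchNorm(2, relu))
length(flat)   # expected to be 4: only β and γ (2 + 2); μ, σ², etc. stay inside re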
