Merge #865
865: Functor r=MikeInnes a=MikeInnes

This refactors our current `@treelike` infrastructure. It somewhat formalises what we're doing around the idea of a Flux model as a functor, i.e. something that can be mapped over.

This is much more flexible than what we had before, and avoids some issues. It allows layers to have state that isn't mappable; it allows for dispatch when walking the tree, which means layers like `BatchNorm` can have non-trainable parameters; and it also allows for zipped mapping like `fmap(+, xs, ys)`, which isn't implemented yet but will be useful for the new optimisers work.
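
As a rough sketch of the idea (not part of this PR's diff; `MyNorm` is a made-up layer), a custom type can keep all of its fields mappable under `fmap` while exposing only some of them as trainable:

```julia
using Flux

# A layer with trainable parameters (β, γ) and non-trainable state (μ).
struct MyNorm
  β
  γ
  μ
end

Flux.@functor MyNorm                    # all three fields become functor children

Flux.trainable(m::MyNorm) = (m.β, m.γ)  # only β and γ are collected by `params`

m = MyNorm(zeros(3), ones(3), zeros(3))
fmap(x -> Float32.(x), m)               # maps over β, γ and μ alike
params(m)                               # contains β and γ only
```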

The main downside is that the term `functor` has been previously used in the Julia community as a malapropism for "thing that behaves like a function"; but hopefully this can start to reduce that usage.

Co-authored-by: Mike Innes <mike.j.innes@gmail.com>
bors[bot] and MikeInnes authored Sep 19, 2019
2 parents b60df53 + cabb81e commit 797c39d
Showing 12 changed files with 131 additions and 127 deletions.
6 changes: 3 additions & 3 deletions docs/src/gpu.md
@@ -25,16 +25,16 @@ loss(x, y) # ~ 3

Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.

-If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once.
+If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.

```julia
d = Dense(10, 5, σ)
-d = mapleaves(cu, d)
+d = fmap(cu, d)
d.W # Tracked CuArray
d(cu(rand(10))) # CuArray output

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
-m = mapleaves(cu, m)
+m = fmap(cu, m)
d(cu(rand(10)))
```

2 changes: 1 addition & 1 deletion docs/src/models/basics.md
@@ -215,7 +215,7 @@ m(5) # => 26
Flux provides a set of helpers for custom layers, which you can enable by calling

```julia
-Flux.@treelike Affine
+Flux.@functor Affine
```

This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
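
For reference, given the `Affine` example from the docs (fields `W` and `b`), the macro call above roughly expands to a `functor` definition like the following (a sketch based on `makefunctor` in `src/functor.jl` below):

```julia
# Children as a named tuple, plus a closure that rebuilds an Affine from new children.
Flux.functor(x::Affine) = (W = x.W, b = x.b), y -> Affine(y...)
```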
4 changes: 2 additions & 2 deletions src/Flux.jl
@@ -11,7 +11,7 @@ export gradient

export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-SkipConnection, params, mapleaves, cpu, gpu, f32, f64
+SkipConnection, params, fmap, cpu, gpu, f32, f64

include("optimise/Optimise.jl")
using .Optimise
@@ -35,7 +35,7 @@ end

include("utils.jl")
include("onehot.jl")
-include("treelike.jl")
+include("functor.jl")

include("layers/stateless.jl")
include("layers/basic.jl")
91 changes: 91 additions & 0 deletions src/functor.jl
@@ -0,0 +1,91 @@
import Adapt: adapt, adapt_storage
using Zygote: IdSet

functor(x) = (), _ -> x

functor(x::Tuple) = x, y -> y
functor(x::NamedTuple) = x, y -> y

functor(x::AbstractArray) = x, y -> y
functor(x::AbstractArray{<:Number}) = (), _ -> x

function makefunctor(m::Module, T, fs = fieldnames(T))
@eval m begin
Flux.functor(x::$T) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...)
end
end

function functorm(T, fs = nothing)
fs == nothing || isexpr(fs, :tuple) || error("@functor T (a, b)")
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
:(makefunctor(@__MODULE__, $(esc(T)), $(fs...)))
end

macro functor(args...)
functorm(args...)
end

isleaf(x) = functor(x)[1] === ()

function fmap1(f, x)
func, re = functor(x)
re(map(f, func))
end

function fmap(f, x; cache = IdDict())
haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
end

trainable(m) = functor(m)[1]

params!(p::Params, x::AbstractArray{<:Real}, seen = IdSet()) = push!(p, x)

function params!(p::Params, x, seen = IdSet())
x in seen && return
push!(seen, x)
for child in trainable(x)
params!(p, child, seen)
end
end

function params(m...)
ps = Params()
params!(ps, m)
return ps
end

# Deprecated stuff
macro treelike(args...)
functorm(args...)
end
mapleaves(f, x) = fmap(f, x)

function loadparams!(m, xs)
for (p, x) in zip(params(m), xs)
size(p) == size(x) ||
error("Expected param size $(size(p)), got $(size(x))")
copyto!(p, x)
end
end

# CPU/GPU movement conveniences

cpu(m) = fmap(x -> adapt(Array, x), m)

const gpu_adaptor = if has_cuarrays()
CuArrays.cu
else
identity
end

gpu(x) = fmap(gpu_adaptor, x)

# Precision

adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)

paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)

f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)
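
A brief usage sketch of the pieces above (not part of the diff); it assumes only the definitions in this file plus the `Dense`/`Chain` layers:

```julia
using Flux

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

m64 = f64(m)    # fmap(x -> adapt(Float64, x), m): rebuilds each layer around converted leaves
ps  = params(m) # walks `trainable`, collecting the weight and bias arrays

# The IdDict cache in `fmap` means a node visited twice (e.g. a layer shared
# between two positions) is transformed once and stays shared in the result.
shared = Dense(5, 5)
tied   = Chain(shared, shared)
tied32 = f32(tied)
tied32[1].W === tied32[2].W  # true
```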
11 changes: 5 additions & 6 deletions src/layers/basic.jl
@@ -24,8 +24,7 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex

-children(c::Chain) = c.layers
-mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
+functor(c::Chain) = c.layers, ls -> Chain(ls...)

applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
@@ -92,7 +91,7 @@ function Dense(in::Integer, out::Integer, σ = identity;
return Dense(initW(out, in), initb(out), σ)
end

-@treelike Dense
+@functor Dense

function (a::Dense)(x::AbstractArray)
W, b, σ = a.W, a.b, a.σ
@@ -131,7 +130,7 @@ end
Diagonal(in::Integer; initα = ones, initβ = zeros) =
Diagonal(initα(in), initβ(in))

-@treelike Diagonal
+@functor Diagonal

function (a::Diagonal)(x)
α, β = a.α, a.β
@@ -184,7 +183,7 @@ function Maxout(f, n_alts)
return Maxout(over)
end

-@treelike Maxout
+@functor Maxout

function (mo::Maxout)(input::AbstractArray)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
@@ -209,7 +208,7 @@ struct SkipConnection
connection #user can pass arbitrary connections here, such as (a,b) -> a + b
end

-@treelike SkipConnection
+@functor SkipConnection

function (skip::SkipConnection)(input)
#We apply the layers to the input and return the result of the application of the layers and the original input
8 changes: 4 additions & 4 deletions src/layers/conv.jl
@@ -45,7 +45,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
Conv(init(k..., ch...), zeros(ch[2]), σ,
stride = stride, pad = pad, dilation = dilation)

-@treelike Conv
+@functor Conv

function (c::Conv)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
@@ -102,7 +102,7 @@ ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
ConvTranspose(init(k..., reverse(ch)...), zeros(ch[2]), σ,
stride = stride, pad = pad, dilation = dilation)

-@treelike ConvTranspose
+@functor ConvTranspose

function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate size of "input", from ∇conv_data()'s perspective...
@@ -180,7 +180,7 @@ function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
)
end

-@treelike DepthwiseConv
+@functor DepthwiseConv

function (c::DepthwiseConv)(x)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
@@ -244,7 +244,7 @@ CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
CrossCor(init(k..., ch...), zeros(ch[2]), σ,
stride = stride, pad = pad, dilation = dilation)

-@treelike CrossCor
+@functor CrossCor

function crosscor(x, w, ddims::DenseConvDims)
ddims = DenseConvDims(ddims, F=true)
26 changes: 10 additions & 16 deletions src/layers/normalise.jl
@@ -82,7 +82,7 @@ end
LayerNorm(h::Integer) =
LayerNorm(Diagonal(h))

-@treelike LayerNorm
+@functor LayerNorm

(a::LayerNorm)(x) = a.diag(normalise(x))

@@ -134,6 +134,8 @@ BatchNorm(chs::Integer, λ = identity;
BatchNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum)

+trainable(bn::BatchNorm) = (bn.β, bn.γ)

function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) ||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
@@ -166,11 +168,7 @@ function (BN::BatchNorm)(x)
end
end

-children(BN::BatchNorm) =
-(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum)

-mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
-BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum)
+@functor BatchNorm

function Base.show(io::IO, l::BatchNorm)
print(io, "BatchNorm($(join(size(l.β), ", "))")
@@ -224,6 +222,8 @@ InstanceNorm(chs::Integer, λ = identity;
InstanceNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum)

+trainable(in::InstanceNorm) = (in.β, in.γ)

function (in::InstanceNorm)(x)
size(x, ndims(x)-1) == length(in.β) ||
error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
@@ -261,11 +261,7 @@ function (in::InstanceNorm)(x)
end
end

-children(in::InstanceNorm) =
-(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum)

-mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
-InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum)
+@functor InstanceNorm

function Base.show(io::IO, l::InstanceNorm)
print(io, "InstanceNorm($(join(size(l.β), ", "))")
@@ -311,6 +307,8 @@ GroupNorm(chs::Integer, G::Integer, λ = identity;
GroupNorm(G, λ, initβ(chs), initγ(chs),
zeros(G,1), ones(G,1), ϵ, momentum)

+trainable(gn::GroupNorm) = (gn.β, gn.γ)

function(gn::GroupNorm)(x)
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
ndims(x) > 2 || error("Need to pass at least 3 channels for Group Norm to work")
@@ -360,11 +358,7 @@ function(gn::GroupNorm)(x)
end
end

-children(gn::GroupNorm) =
-(gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum)

-mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN)
-GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum)
+@functor GroupNorm

function Base.show(io::IO, l::GroupNorm)
print(io, "GroupNorm($(join(size(l.β), ", "))")
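
To make the effect of the changes above concrete, a short sketch (not part of the diff) of how the normalisation layers now behave:

```julia
using Flux

bn = BatchNorm(4)

params(bn)  # only bn.β and bn.γ, via the `trainable` overload above
f32(bn)     # converts β, γ and the running statistics μ, σ² alike,
            # because `fmap` walks all functor children of BatchNorm
```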
11 changes: 6 additions & 5 deletions src/layers/recurrent.jl
@@ -38,7 +38,7 @@ function (m::Recur)(xs...)
return y
end

-@treelike Recur cell, init
+@functor Recur cell, init

Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")

@@ -52,7 +52,8 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = hidden(rnn.cell)
"""
-reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)
+reset!(m::Recur) = (m.state = m.init)
+reset!(m) = foreach(reset!, functor(m)[1])

flip(f, xs) = reverse(f.(reverse(xs)))

Expand All @@ -79,7 +80,7 @@ end

hidden(m::RNNCell) = m.h

-@treelike RNNCell
+@functor RNNCell

function Base.show(io::IO, l::RNNCell)
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
@@ -127,7 +128,7 @@ end

hidden(m::LSTMCell) = (m.h, m.c)

-@treelike LSTMCell
+@functor LSTMCell

Base.show(io::IO, l::LSTMCell) =
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
@@ -168,7 +169,7 @@ end

hidden(m::GRUCell) = m.h

-@treelike GRUCell
+@functor GRUCell

Base.show(io::IO, l::GRUCell) =
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
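
A short sketch (not part of the diff) of the new `reset!` behaviour on a nested model:

```julia
using Flux

m = Chain(LSTM(10, 5), Dense(5, 2))
m(rand(10))     # advances the LSTM's hidden state
Flux.reset!(m)  # recurses through functor children and restores state = init
```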
