-
-
Notifications
You must be signed in to change notification settings - Fork 608
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
865: Functor r=MikeInnes a=MikeInnes This refactors our current `@treelike` infrastructure. It somewhat formalises what we're doing around the idea of a Flux model as a functor, i.e. something that can be mapped over. This is much more flexible than what we had before, and avoids some issues. It allows layers to have state that isn't mappable; it allows for dispatch when walking the tree, which means layers like `BatchNorm` can have non-trainable parameters; and it also allows for zipped mapping like `fmap(+, xs, ys)`, which isn't implemented yet but will be useful for the new optimisers work. The main downside is that the term `functor` has been previously used in the Julia community as a malapropism for "thing that behaves like a function"; but hopefully this can start to reduce that usage. Co-authored-by: Mike Innes <mike.j.innes@gmail.com>
- Loading branch information
Showing
12 changed files
with
131 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import Adapt: adapt, adapt_storage | ||
using Zygote: IdSet | ||
|
||
functor(x) = (), _ -> x | ||
|
||
functor(x::Tuple) = x, y -> y | ||
functor(x::NamedTuple) = x, y -> y | ||
|
||
functor(x::AbstractArray) = x, y -> y | ||
functor(x::AbstractArray{<:Number}) = (), _ -> x | ||
|
||
function makefunctor(m::Module, T, fs = fieldnames(T)) | ||
@eval m begin | ||
Flux.functor(x::$T) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...) | ||
end | ||
end | ||
|
||
function functorm(T, fs = nothing) | ||
fs == nothing || isexpr(fs, :tuple) || error("@functor T (a, b)") | ||
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)] | ||
:(makefunctor(@__MODULE__, $(esc(T)), $(fs...))) | ||
end | ||
|
||
macro functor(args...) | ||
functorm(args...) | ||
end | ||
|
||
isleaf(x) = functor(x)[1] === () | ||
|
||
function fmap1(f, x) | ||
func, re = functor(x) | ||
re(map(f, func)) | ||
end | ||
|
||
function fmap(f, x; cache = IdDict()) | ||
haskey(cache, x) && return cache[x] | ||
cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x) | ||
end | ||
|
||
trainable(m) = functor(m)[1] | ||
|
||
params!(p::Params, x::AbstractArray{<:Real}, seen = IdSet()) = push!(p, x) | ||
|
||
function params!(p::Params, x, seen = IdSet()) | ||
x in seen && return | ||
push!(seen, x) | ||
for child in trainable(x) | ||
params!(p, child, seen) | ||
end | ||
end | ||
|
||
function params(m...) | ||
ps = Params() | ||
params!(ps, m) | ||
return ps | ||
end | ||
|
||
# Deprecated stuff | ||
macro treelike(args...) | ||
functorm(args...) | ||
end | ||
mapleaves(f, x) = fmap(f, x) | ||
|
||
function loadparams!(m, xs) | ||
for (p, x) in zip(params(m), xs) | ||
size(p) == size(x) || | ||
error("Expected param size $(size(p)), got $(size(x))") | ||
copyto!(p, x) | ||
end | ||
end | ||
|
||
# CPU/GPU movement conveniences | ||
|
||
cpu(m) = fmap(x -> adapt(Array, x), m) | ||
|
||
const gpu_adaptor = if has_cuarrays() | ||
CuArrays.cu | ||
else | ||
identity | ||
end | ||
|
||
gpu(x) = fmap(gpu_adaptor, x) | ||
|
||
# Precision | ||
|
||
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs) | ||
|
||
paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m) | ||
|
||
f32(m) = paramtype(Float32, m) | ||
f64(m) = paramtype(Float64, m) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.