Introduce macro to easily create custom layers (#4)
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
3 people authored Feb 24, 2023
1 parent 8650d54 commit d917e17
Showing 5 changed files with 402 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
 name = "Fluxperimental"
 uuid = "3102ee7a-c841-4564-8f7f-ec69bd4fd658"
-version = "0.1.0"
+version = "0.1.1"
 
 [deps]
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
2 changes: 2 additions & 0 deletions src/Fluxperimental.jl
@@ -8,4 +8,6 @@ export Split, Join
include("train.jl")
export shinkansen!

include("compact.jl")

end # module Fluxperimental
214 changes: 214 additions & 0 deletions src/compact.jl
@@ -0,0 +1,214 @@
import Flux: _big_show

"""
    @compact(forward::Function; name=nothing, parameters...)

Creates a layer by specifying some `parameters`, in the form of keywords,
and (usually as a `do` block) a function for the forward pass.
You may think of `@compact` as a specialized `let` block creating local variables
that are trainable in Flux.
Declared variable names may be used within the body of the `forward` function.

Here is a linear model:
```
r = @compact(w = rand(3)) do x
w .* x
end
r([1, 1, 1]) # x is set to [1, 1, 1].
```
Here is a linear model with bias and activation:
```
d = @compact(in=5, out=7, W=randn(out, in), b=zeros(out), act=relu) do x
y = W * x
act.(y .+ b)
end
d(ones(5, 10)) # 7×10 Matrix as output.
```
Finally, here is a simple MLP:
```
using Flux
n_in = 1
n_out = 1
nlayers = 3
model = @compact(
w1=Dense(n_in, 128),
w2=[Dense(128, 128) for i=1:nlayers],
w3=Dense(128, n_out),
act=relu
) do x
embed = act(w1(x))
for w in w2
embed = act(w(embed))
end
out = w3(embed)
return out
end
model(randn(n_in, 32)) # 1×32 Matrix as output.
```
We can train this model just like any `Chain`:
```
data = [([x], 2x-x^3) for x in -2:0.1f0:2]
optim = Flux.setup(Adam(), model)
for epoch in 1:1000
Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), model, data, optim)
end
```
You may also specify a `name` for the model, which will be used
instead of the default printout (a verbatim representation of the
code used to construct the model):
```
model = @compact(w=rand(3), name="Linear(3 => 1)") do x
sum(w .* x)
end
println(model) # "Linear(3 => 1)"
```
This can be useful when using `@compact` to hierarchically construct
complex models to be used inside a `Chain`.
"""
macro compact(fex, kwexs...)
# check input
Meta.isexpr(fex, :(->)) || error("expects a do block")
isempty(kwexs) && error("expects keyword arguments")
all(ex -> Meta.isexpr(ex, :kw), kwexs) || error("expects only keyword arguments")

# check if user has named layer:
name = findfirst(ex -> ex.args[1] == :name, kwexs)
if name !== nothing && kwexs[name].args[2] !== nothing
length(kwexs) == 1 && error("expects keyword arguments")
name_str = kwexs[name].args[2]
# remove name from kwexs (a tuple)
kwexs = (kwexs[1:name-1]..., kwexs[name+1:end]...)
name = name_str
end

# make strings
layer = "@compact"
setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs))
input = join(fex.args[1].args, ", ")
block = string(Base.remove_linenums!(fex).args[2])

# edit expressions
vars = map(ex -> ex.args[1], kwexs)
assigns = map(ex -> Expr(:(=), ex.args...), kwexs)
@gensym self
pushfirst!(fex.args[1].args, self)
addprefix!(fex, self, vars)

# assemble
return esc(quote
let
$(assigns...)
$CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(vars...))
end
end)
end
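
# Illustration (not part of the commit): under the definitions above, a call
#
#   @compact(w = rand(3)) do x
#       w .* x
#   end
#
# expands to roughly the following, where `self` is a fresh gensym:
#
#   let
#       w = rand(3)
#       CompactLayer((self, x) -> self.w .* x, nothing,
#                    ("@compact", "x", "begin\n    w .* x\nend"),
#                    (w = "rand(3)",); w)
#   end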

function addprefix!(ex::Expr, self, vars)
for i = 1:length(ex.args)
if ex.args[i] in vars
ex.args[i] = :($self.$(ex.args[i]))
else
addprefix!(ex.args[i], self, vars)
end
end
end
addprefix!(not_ex, self, vars) = nothing
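
# Illustration (not part of the commit): `addprefix!` rewrites every reference
# to a declared variable into a field access on the hidden first argument:
#
#   julia> ex = :(act.(W * x .+ b));
#
#   julia> addprefix!(ex, :self, [:W, :b, :act]); ex
#   :(self.act.(self.W * x .+ self.b))
#
# Symbols not listed in `vars` (here `x`) are left untouched.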

struct CompactLayer{F,NT1<:NamedTuple,NT2<:NamedTuple}
fun::F
name::Union{String,Nothing}
strings::NTuple{3,String}
setup_strings::NT1
variables::NT2
end
CompactLayer(f::Function, name::Union{String,Nothing}, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, name, str, setup_str, NamedTuple(kw))
(m::CompactLayer)(x...) = m.fun(m.variables, x...)
CompactLayer(args...) = error("CompactLayer is meant to be constructed by the macro")
Flux.@functor CompactLayer
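
# Illustration (not part of the commit): `Flux.@functor` exposes the fields of
# `CompactLayer` (in particular `variables`) to Flux, so gradients and
# optimisers reach the arrays declared in the macro call:
#
#   m = @compact(w = rand(3)) do x
#       sum(w .* x)
#   end
#   g = Flux.gradient(m -> m([1.0, 2.0, 3.0]), m)  # g[1].variables.w is ∂loss/∂w
#   opt = Flux.setup(Adam(), m)                    # optimiser state tracks w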

Flux._show_children(m::CompactLayer) = m.variables

function Base.show(io::IO, ::MIME"text/plain", m::CompactLayer)
if get(io, :typeinfo, nothing) === nothing # e.g., top level of REPL
Flux._big_show(io, m)
elseif !get(io, :compact, false) # e.g., printed inside a Vector, but not a matrix
Flux._layer_show(io, m)
else
show(io, m)
end
end

function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
setup_strings = obj.setup_strings
local_name = obj.name
has_explicit_name = local_name !== nothing
if has_explicit_name
if indent != 0 || length(Flux.params(obj)) <= 2
_just_show_params(io, local_name, obj, indent)
else # indent == 0 and more than 2 parameters
print(io, local_name)
Flux._big_finale(io, obj)
end
else # no name, so print normally
layer, input, block = obj.strings
pre, post = ("(", ")")
println(io, " "^indent, isnothing(name) ? "" : "$name = ", layer, pre)
for k in keys(obj.variables)
v = obj.variables[k]
if Flux._show_leaflike(v)
# If the value is a leaf, just print verbatim what the user wrote:
str = String(k) * " = " * setup_strings[k]
_just_show_params(io, str, v, indent+2)
else
Flux._big_show(io, v, indent+2, String(k))
end
end
if indent == 0 # i.e. this is the outermost container
print(io, rpad(post, 1))
else
print(io, " "^indent, post)
end

input != "" && print(io, " do ", input)
if block != ""
block_to_print = block[6:end]
# Increase indentation of block according to `indent`:
block_to_print = replace(block_to_print, r"\n" => "\n" * " "^(indent))
print(io, " ", block_to_print)
end
if indent == 0
Flux._big_finale(io, obj)
else
println(io, ",")
end
end
end

# Modified from src/layers/show.jl
function _just_show_params(io::IO, str::String, layer, indent::Int=0)
print(io, " "^indent, str, indent==0 ? "" : ",")
if !isempty(Flux.params(layer))
print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str)))
printstyled(io, "# ", Flux.underscorise(sum(length, Flux.params(layer))), " parameters"; color=:light_black)
nonparam = Flux._childarray_sum(length, layer) - sum(length, Flux.params(layer))
if nonparam > 0
printstyled(io, ", plus ", Flux.underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black)
end
Flux._nan_show(io, Flux.params(layer))
end
indent==0 || println(io)
end
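
# Illustration (not part of the commit): how these show methods render an
# unnamed layer at the REPL. The printout echoes the construction code, roughly:
#
#   julia> @compact(w = rand(3)) do x
#              sum(w .* x)
#          end
#   @compact(
#     w = rand(3),                  # 3 parameters
#   ) do x
#       sum(w .* x)
#   end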

3 comments on commit d917e17

@MilesCranmer (Contributor, Author)

Can one of you @darsnack @mcabbott trigger JuliaRegistrator?

@ToucheSir (Member)

@JuliaRegistrator register

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/79123

After the above pull request is merged, it is recommended that a tag be created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or it can be done manually through the GitHub interface, or via:

git tag -a v0.1.1 -m "<description of version>" d917e170dbf1cd9d5deda37e78c65425719d04f2
git push origin v0.1.1
