From d917e170dbf1cd9d5deda37e78c65425719d04f2 Mon Sep 17 00:00:00 2001
From: Miles Cranmer
Date: Fri, 24 Feb 2023 14:48:43 -0500
Subject: [PATCH] Introduce macro to easily create custom layers (#4)

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-authored-by: Kyle Daruwalla
---
 Project.toml          |   2 +-
 src/Fluxperimental.jl |   2 +
 src/compact.jl        | 214 ++++++++++++++++++++++++++++++++++++++++++
 test/compact.jl       | 184 ++++++++++++++++++++++++++++++++++++
 test/runtests.jl      |   1 +
 5 files changed, 402 insertions(+), 1 deletion(-)
 create mode 100644 src/compact.jl
 create mode 100644 test/compact.jl

diff --git a/Project.toml b/Project.toml
index a44f650..5884014 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Fluxperimental"
 uuid = "3102ee7a-c841-4564-8f7f-ec69bd4fd658"
-version = "0.1.0"
+version = "0.1.1"

 [deps]
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
diff --git a/src/Fluxperimental.jl b/src/Fluxperimental.jl
index 26026d3..948d5ed 100644
--- a/src/Fluxperimental.jl
+++ b/src/Fluxperimental.jl
@@ -8,4 +8,6 @@ export Split, Join
 include("train.jl")
 export shinkansen!

+include("compact.jl")
+
 end # module Fluxperimental
diff --git a/src/compact.jl b/src/compact.jl
new file mode 100644
index 0000000..d9cbb37
--- /dev/null
+++ b/src/compact.jl
@@ -0,0 +1,214 @@
+import Flux: _big_show
+
"""
    @compact(forward::Function; name=nothing, parameters...)

Creates a layer by specifying some `parameters`, in the form of keywords,
and (usually as a `do` block) a function for the forward pass.
You may think of `@compact` as a specialized `let` block creating local variables
that are trainable in Flux.
Declared variable names may be used within the body of the `forward` function.

Here is a linear model:

```
r = @compact(w = rand(3)) do x
  w .* x
end
r([1, 1, 1]) # x is set to [1, 1, 1].
```

Here is a linear model with bias and activation:

```
d = @compact(in=5, out=7, W=randn(out, in), b=zeros(out), act=relu) do x
  y = W * x
  act.(y .+ b)
end
d(ones(5, 10)) # 7×10 Matrix as output.
```

Finally, here is a simple MLP:

```
using Flux

n_in = 1
n_out = 1
nlayers = 3

model = @compact(
  w1=Dense(n_in, 128),
  w2=[Dense(128, 128) for i=1:nlayers],
  w3=Dense(128, n_out),
  act=relu
) do x
  embed = act(w1(x))
  for w in w2
    embed = act(w(embed))
  end
  out = w3(embed)
  return out
end

model(randn(n_in, 32)) # 1×32 Matrix as output.
```

We can train this model just like any `Chain`:

```
data = [([x], 2x-x^3) for x in -2:0.1f0:2]
optim = Flux.setup(Adam(), model)

for epoch in 1:1000
  Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), model, data, optim)
end
```

You may also specify a `name` for the model, which will
be used instead of the default printout, which gives a verbatim
representation of the code used to construct the model:

```
model = @compact(w=rand(3), name="Linear(3 => 1)") do x
  sum(w .* x)
end
println(model) # "Linear(3 => 1)"
```

This can be useful when using `@compact` to hierarchically construct
complex models to be used inside a `Chain`.
"""
macro compact(fex, kwexs...)
  # check input
  Meta.isexpr(fex, :(->)) || error("expects a do block")
  isempty(kwexs) && error("expects keyword arguments")
  all(ex -> Meta.isexpr(ex, :kw), kwexs) || error("expects only keyword arguments")

  # check if user has named layer:
  name = findfirst(ex -> ex.args[1] == :name, kwexs)
  if name !== nothing && kwexs[name].args[2] !== nothing
    length(kwexs) == 1 && error("expects keyword arguments")
    name_str = kwexs[name].args[2]
    # remove name from kwexs (a tuple)
    kwexs = (kwexs[1:name-1]..., kwexs[name+1:end]...)
    name = name_str
  end

  # make strings
  layer = "@compact"
  setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs))
  input = join(fex.args[1].args, ", ")
  block = string(Base.remove_linenums!(fex).args[2])

  # edit expressions
  vars = map(ex -> ex.args[1], kwexs)
  assigns = map(ex -> Expr(:(=), ex.args...), kwexs)
  @gensym self
  pushfirst!(fex.args[1].args, self)
  addprefix!(fex, self, vars)

  # assemble
  return esc(quote
    let
      $(assigns...)
      $CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(vars...))
    end
  end)
end

function addprefix!(ex::Expr, self, vars)
  for i = 1:length(ex.args)
    if ex.args[i] in vars
      ex.args[i] = :($self.$(ex.args[i]))
    else
      addprefix!(ex.args[i], self, vars)
    end
  end
end
addprefix!(not_ex, self, vars) = nothing

struct CompactLayer{F,NT1<:NamedTuple,NT2<:NamedTuple}
  fun::F
  name::Union{String,Nothing}
  strings::NTuple{3,String}
  setup_strings::NT1
  variables::NT2
end
CompactLayer(f::Function, name::Union{String,Nothing}, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, name, str, setup_str, NamedTuple(kw))
(m::CompactLayer)(x...) = m.fun(m.variables, x...)
CompactLayer(args...) = error("CompactLayer is meant to be constructed by the macro")
Flux.@functor CompactLayer

Flux._show_children(m::CompactLayer) = m.variables

function Base.show(io::IO, ::MIME"text/plain", m::CompactLayer)
  if get(io, :typeinfo, nothing) === nothing # e.g., top level of REPL
    Flux._big_show(io, m)
  elseif !get(io, :compact, false) # e.g., printed inside a Vector, but not a matrix
    Flux._layer_show(io, m)
  else
    show(io, m)
  end
end

function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
  setup_strings = obj.setup_strings
  local_name = obj.name
  has_explicit_name = local_name !== nothing
  if has_explicit_name
    if indent != 0 || length(Flux.params(obj)) <= 2
      _just_show_params(io, local_name, obj, indent)
    else # indent == 0
      print(io, local_name)
      Flux._big_finale(io, obj)
    end
  else # no name, so print normally
    layer, input, block = obj.strings
    pre, post = ("(", ")")
    println(io, " "^indent, isnothing(name) ? "" : "$name = ", layer, pre)
    for k in keys(obj.variables)
      v = obj.variables[k]
      if Flux._show_leaflike(v)
        # If the value is a leaf, just print verbatim what the user wrote:
        str = String(k) * " = " * setup_strings[k]
        _just_show_params(io, str, v, indent+2)
      else
        Flux._big_show(io, v, indent+2, String(k))
      end
    end
    if indent == 0 # i.e.
this is the outermost container + print(io, rpad(post, 1)) + else + print(io, " "^indent, post) + end + + input != "" && print(io, " do ", input) + if block != "" + block_to_print = block[6:end] + # Increase indentation of block according to `indent`: + block_to_print = replace(block_to_print, r"\n" => "\n" * " "^(indent)) + print(io, " ", block_to_print) + end + if indent == 0 + Flux._big_finale(io, obj) + else + println(io, ",") + end + end +end + +# Modified from src/layers/show.jl +function _just_show_params(io::IO, str::String, layer, indent::Int=0) + print(io, " "^indent, str, indent==0 ? "" : ",") + if !isempty(Flux.params(layer)) + print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) + printstyled(io, "# ", Flux.underscorise(sum(length, Flux.params(layer))), " parameters"; color=:light_black) + nonparam = Flux._childarray_sum(length, layer) - sum(length, Flux.params(layer)) + if nonparam > 0 + printstyled(io, ", plus ", Flux.underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black) + end + Flux._nan_show(io, Flux.params(layer)) + end + indent==0 || println(io) +end diff --git a/test/compact.jl b/test/compact.jl new file mode 100644 index 0000000..06ebe8e --- /dev/null +++ b/test/compact.jl @@ -0,0 +1,184 @@ +import Fluxperimental: @compact + +# Strip both strings of spaces, and then test: +function similar_strings(s1, s2) + s1 = replace(s1, r"\s" => "") + s2 = replace(s2, r"\s" => "") + + # We also remove any instances of, e.g., + # 17.057 KiB (or any other number) + # because this depends on indentation in this file. + s1 = replace(s1, r"\d+\.\d+KiB" => "") + s2 = replace(s2, r"\d+\.\d+KiB" => "") + + # Display any differences: + if s1 != s2 + println(stderr, "s1: ", s1) + println(stderr, "s2: ", s2) + end + return s1 == s2 +end + +function get_model_string(model) + io = IOBuffer() + show(io, MIME"text/plain"(), model) + String(take!(io)) +end + +@testset "@compact" begin + + r = @compact(w = [1, 5, 10]) do x + sum(w .* x) + end + @test Flux.params(r) == Flux.Params([[1, 5, 10]]) + @test r([1, 1, 1]) == 1 + 5 + 10 + @test r([1, 2, 3]) == 1 + 2 * 5 + 3 * 10 + @test r(ones(3, 3)) == 3 * (1 + 5 + 10) + + # Test gradients: + @test gradient(r, [1, 1, 1])[1] == [1, 5, 10] + + d = @compact(in = 5, out = 7, W = randn(out, in), b = zeros(out), act = relu) do x + y = W * x + act.(y .+ b) + end + + @test size.(Flux.params(d)) == [(7, 5), (7,)] + + @test size(d(ones(5, 10))) == (7, 10) + @test all(d(randn(5, 10)) .>= 0) + + # Test gradients: + y, ∇ = Flux.withgradient(Flux.params(d)) do + input = randn(5, 32) + desired_output = randn(7, 32) + prediction = d(input) + sum((prediction - desired_output) .^ 2) + end + @test typeof(y) == Float64 + grads = ∇.grads + @test typeof(grads) <: IdDict + @test length(grads) == 3 + @test Set(size.(values(grads))) == Set([(7, 5), (), (7,)]) + + + # MLP: + n_in = 1 + n_out = 1 + nlayers = 3 + + model = @compact( + w1 = Dense(n_in, 128), + w2 = [Dense(128, 128) for i = 1:nlayers], + w3 = Dense(128, n_out), + act = relu + ) do x + embed = act(w1(x)) + for w in w2 + embed = act(w(embed)) + end + out = w3(embed) + return out + end + + @test size.(Flux.params(model)) == [ + (128, 1), + (128,), + (128, 128), + (128,), + (128, 128), + (128,), + (128, 128), + (128,), + (1, 128), + (1,), + ] + @test size(model(randn(n_in, 32))) == (1, 32) + + # Test string representations: + model = @compact(w=Dense(32 => 32)) do x, y + tmp = sum(w(x)) + return tmp + y + end + expected_string = """@compact( + w = Dense(32=>32), #1_056 parameters + ) 
do x, y + tmp = sum(w(x)) + return tmp + y + end""" + @test similar_strings(get_model_string(model), expected_string) + + # Custom naming: + model = @compact(w=Dense(32, 32), name="Linear(...)") do x, y + tmp = sum(w(x)) + return tmp + y + end + expected_string = "Linear(...) # 1_056 parameters" + @test similar_strings(get_model_string(model), expected_string) + + # Hierarchical models should work too: + model1 = @compact(w1=Dense(32=>32, relu), w2=Dense(32=>32, relu)) do x + w2(w1(x)) + end + model2 = @compact(w1=model1, w2=Dense(32=>32, relu)) do x + w2(w1(x)) + end + expected_string = """@compact( + w1 = @compact( + w1 = Dense(32 => 32, relu), # 1_056 parameters + w2 = Dense(32 => 32, relu), # 1_056 parameters + ) do x + w2(w1(x)) + end, + w2 = Dense(32 => 32, relu), # 1_056 parameters + ) do x + w2(w1(x)) + end # Total: 6 arrays, 3_168 parameters, 13.271 KiB.""" + @test similar_strings(get_model_string(model2), expected_string) + + # With array params: + model = @compact(x=randn(32), w=Dense(32=>32)) do s + w(x .* s) + end + expected_string = """@compact( + x = randn(32), # 32 parameters + w = Dense(32 => 32), # 1_056 parameters + ) do s + w(x .* s) + end # Total: 3 arrays, 1_088 parameters, 4.734 KiB.""" + @test similar_strings(get_model_string(model), expected_string) + + # Hierarchy with inner model named: + model = @compact( + w1=@compact(w1=randn(32, 32), name="Model(32)") do x + w1 * x + end, + w2=randn(32, 32), + w3=randn(32), + ) do x + w2 * w1(x) + end + expected_string = """@compact( + Model(32), # 1_024 parameters + w2 = randn(32, 32), # 1_024 parameters + w3 = randn(32), # 32 parameters + ) do x + w2 * w1(x) + end # Total: 3 arrays, 2_080 parameters, 17.089 KiB.""" + @test similar_strings(get_model_string(model), expected_string) + + # Hierarchy with outer model named: + model = @compact( + w1=@compact(w1=randn(32, 32)) do x + w1 * x + end, + w2=randn(32, 32), + w3=randn(32), + name="Model(32)" + ) do x + w2 * w1(x) + end + expected_string = """Model(32) # Total: 3 arrays, 2_080 parameters, 17.057KiB.""" + @test similar_strings(get_model_string(model), expected_string) + +end diff --git a/test/runtests.jl b/test/runtests.jl index 979b386..7c804c4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,4 +3,5 @@ using Flux, Fluxperimental @testset "Fluxperimental.jl" begin include("split_join.jl") + include("compact.jl") end
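Usage note (not part of the patch): the sketch below shows how a layer built with the new `@compact` macro trains with ordinary Flux tooling, mirroring the docstring examples above. The layer sizes, the toy dataset, and the epoch count are illustrative only, and it assumes a Flux version with explicit-mode `Flux.setup`/`Flux.train!` (0.13.9 or later).

```
using Flux, Fluxperimental

# A small regression model: the keyword arguments become the trainable
# parameters, and the do-block is the forward pass.
model = @compact(w1 = Dense(1 => 32, relu), w2 = Dense(32 => 1)) do x
  w2(w1(x))
end

# Toy data for y = 2x - x^3 on [-2, 2], as in the docstring example.
data = [([x], [2x - x^3]) for x in -2:0.1f0:2]
optim = Flux.setup(Adam(), model)

for epoch in 1:100
  Flux.train!((m, x, y) -> Flux.mse(m(x), y), model, data, optim)
end

# Pretty-printing goes through the Base.show / Flux._big_show methods added here:
show(stdout, MIME"text/plain"(), model)
```

Passing `name="..."` to `@compact` replaces that printout with the given string, as exercised in test/compact.jl.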