forgotten changes
mcabbott committed Oct 1, 2023
1 parent 94ad9dd commit a250c51
Showing 2 changed files with 48 additions and 41 deletions.
1 change: 1 addition & 0 deletions src/Fluxperimental.jl
@@ -12,6 +12,7 @@ export shinkansen!
include("chain.jl")

include("compact.jl")
export @compact

include("new_recur.jl")

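The practical effect of this one-line change is that `@compact` becomes part of Fluxperimental's public API, so `using Fluxperimental` brings it into scope directly. A minimal sanity check (hypothetical usage, assuming a Julia session with this commit installed):

```
using Fluxperimental

# The macro binding is now exported; no Fluxperimental.@compact qualification needed.
@assert isdefined(Fluxperimental, Symbol("@compact"))
```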
88 changes: 47 additions & 41 deletions src/compact.jl
@@ -4,64 +4,66 @@ import Flux: _big_show
     @compact(forward::Function; name=nothing, parameters...)
 
 Creates a layer by specifying some `parameters`, in the form of keywords,
-and (usually as a `do` block) a function for the forward pass.
+and a function for the forward pass (often as a `do` block).
 
 You may think of `@compact` as a specialized `let` block creating local variables
 that are trainable in Flux.
 Declared variable names may be used within the body of the `forward` function.
 
-Here is a linear model:
+# Examples
+
+Here is a linear model, equivalent to `Flux.Scale`:
 
 ```
-r = @compact(w = rand(3)) do x
-  w .* x
-end
-r([1, 1, 1]) # x is set to [1, 1, 1].
+using Flux, Fluxperimental
+
+w = rand(3)
+sc = @compact(x -> x .* w; w)
+sc([1 10 100]) # 3×3 Matrix as output.
+ans ≈ Flux.Scale(w)([1 10 100]) # equivalent Flux layer
 ```
 
-Here is a linear model with bias and activation:
+Here is a linear model with bias and activation, equivalent to Flux's `Dense` layer.
+The forward pass function is now written as a do block, instead of `x -> begin y = W * x; ...`
 
 ```
-d_in = 5
+d_in = 3
 d_out = 7
-d = @compact(W = randn(d_out, d_in), b = zeros(d_out), act = relu) do x
+layer = @compact(W = randn(d_out, d_in), b = zeros(d_out), act = relu) do x
   y = W * x
   act.(y .+ b)
 end
-d(ones(5, 10)) # 7×10 Matrix as output.
-d([1,2,3,4,5]) ≈ Dense(d.variables.W, zeros(7), relu)([1,2,3,4,5]) # Equivalent to a dense layer
+layer(ones(3, 10)) # 7×10 Matrix as output.
+layer([1,2,3]) ≈ Dense(layer.variables.W, zeros(7), relu)([1,2,3]) # equivalent Flux layer
 ```
 
-Finally, here is a simple MLP:
+Finally, here is a simple MLP, equivalent to a `Chain` with 5 `Dense` layers:
 
 ```
-using Flux
-
-n_in = 1
-n_out = 1
+d_in = 1
 nlayers = 3
 
 model = @compact(
-  w1=Dense(n_in, 128),
-  w2=[Dense(128, 128) for i=1:nlayers],
-  w3=Dense(128, n_out),
-  act=relu
+  lay1 = Dense(d_in => 64),
+  lay234 = [Dense(64 => 64) for i=1:nlayers],
+  wlast = rand32(64),
 ) do x
-  embed = act(w1(x))
-  for w in w2
-    embed = act(w(embed))
+  y = tanh.(lay1(x))
+  for lay in lay234
+    y = relu.(lay(y))
   end
-  out = w3(embed)
-  return out
+  return wlast' * y
 end
 
-model(randn(n_in, 32)) # 1×32 Matrix as output.
+model(randn(Float32, d_in, 8)) # 1×8 array as output.
 ```
 
-We can train this model just like any `Chain`:
+We can train this model just like any `Chain`, for example:
 
 ```
-data = [([x], 2x-x^3) for x in -2:0.1f0:2]
+data = [([x], [2x-x^3]) for x in -2:0.1f0:2]
 optim = Flux.setup(Adam(), model)
 
 for epoch in 1:1000
@@ -70,19 +72,23 @@ end
 ```
 """
 macro compact(_exs...)
+  _compact(_exs...) |> esc
+end
+
+function _compact(_exs...)
   # check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs
-  isempty(_exs) && error("expects at least two expressions: a function and at least one keyword")
+  isempty(_exs) && error("@compact expects at least two expressions: a function and at least one keyword")
   if Meta.isexpr(_exs[1], :parameters)
-    length(_exs) >= 2 || error("expects an anonymous function")
+    length(_exs) >= 2 || error("@compact expects an anonymous function")
     fex = _exs[2]
     _kwexs = (_exs[1], _exs[3:end]...)
   else
     fex = _exs[1]
     _kwexs = _exs[2:end]
   end
-  Meta.isexpr(fex, :(->)) || error("expects an anonymous function")
-  isempty(_kwexs) && error("expects keyword arguments")
-  all(ex -> Meta.isexpr(ex, (:kw,:(=),:parameters)), _kwexs) || error("expects only keyword arguments")
+  Meta.isexpr(fex, :(->)) || error("@compact expects an anonymous function")
+  isempty(_kwexs) && error("@compact expects keyword arguments")
+  all(ex -> Meta.isexpr(ex, (:kw,:(=),:parameters)), _kwexs) || error("@compact expects only keyword arguments")
 
   # process keyword arguments
   if Meta.isexpr(_kwexs[1], :parameters) # handle keyword arguments provided after semicolon
@@ -100,20 +106,20 @@ macro compact(_exs...)
     fex_args = fex.args[1]
     isa(fex_args, Symbol) ? string(fex_args) : join(fex_args.args, ", ")
   catch e
-    @warn "Function stringifying does not yet handle all cases. Falling back to empty string for input arguments"
-    ""
+    @warn """@compact's function stringifying does not yet handle all cases. Falling back to "?" """ maxlog=1
+    "?"
   end
-  block = string(Base.remove_linenums!(fex).args[2])
+  block = string(Base.remove_linenums!(fex).args[2]) # TODO make this remove macro comments
 
   # edit expressions
   vars = map(ex -> ex.args[1], kwexs)
-  fex = supportself(fex, vars)
+  fex = _supportself(fex, vars)
 
   # assemble
-  return esc(:($CompactLayer($fex, ($input, $block); $(kwexs...))))
+  return :($CompactLayer($fex, ($input, $block); $(kwexs...)))
 end
 
-function supportself(fex::Expr, vars)
+function _supportself(fex::Expr, vars)
   @gensym self
   @gensym curried_f
   # To avoid having to manipulate fex's arguments and body explicitly, we form a curried function first
@@ -173,7 +179,7 @@ function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
print(io, " "^indent, post)
end

input != "" && print(io, " do ", input)
print(io, " do ", input)
if block != ""
block_to_print = block[6:end]
# Increase indentation of block according to `indent`:
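For reference, here is a consolidated, untested sketch assembled from the updated docstring above: it builds the MLP example and trains it on the docstring's toy data. Everything comes from the diff except the loss function and the final evaluation line, which are illustrative assumptions.

```
using Flux, Fluxperimental

d_in = 1
nlayers = 3

# MLP from the new docstring: Dense layers plus a trainable final vector.
model = @compact(
  lay1 = Dense(d_in => 64),
  lay234 = [Dense(64 => 64) for i=1:nlayers],
  wlast = rand32(64),
) do x
  y = tanh.(lay1(x))
  for lay in lay234
    y = relu.(lay(y))
  end
  return wlast' * y
end

# Training in the explicit Flux style used by the docstring.
data = [([x], [2x-x^3]) for x in -2:0.1f0:2]
optim = Flux.setup(Adam(), model)
for epoch in 1:1000
  Flux.train!((m, x, y) -> sum(abs2, m(x) .- y), model, data, optim)  # assumed loss
end

model([0.5f0])  # scalar output, ≈ 2*0.5 - 0.5^3 after training
```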
