Skip to content

Commit

Permalink
Add @autostruct macro (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Oct 29, 2024
1 parent 4fe8b4e commit 197d494
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.10'
- '1'
- 'nightly'
os:
Expand Down
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Compat = "4"
Flux = "0.13.7, 0.14"
NNlib = "0.8.10, 0.9"
Optimisers = "0.2.10, 0.3"
Flux = "0.14.23"
NNlib = "0.9"
Optimisers = "0.3"
ProgressMeter = "1.7.2"
Zygote = "0.6.49"
julia = "1.6"
julia = "1.10"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
13 changes: 9 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Fluxperimental.jl

[![][action-img]][action-url]
[![][coverage-img]][coverage-url]
[![][coverage-img]][coverage-url]

[action-img]: https://github.com/FluxML/Fluxperimental.jl/workflows/CI/badge.svg
[action-url]: https://github.com/FluxML/Fluxperimental.jl/actions
Expand All @@ -12,14 +12,14 @@
[coverage-url]: https://codecov.io/gh/FluxML/Fluxperimental.jl


The repository contains experimental features for [Flux.jl](https://github.com/FluxML/Flux.jl).
This contains experimental features for [Flux.jl](https://github.com/FluxML/Flux.jl).
It needs to be loaded in addition to the main package:

```julia
using Flux, Fluxperimental
```

As an experiment, it only has discussion pages, not issues. Actual bugs reports are welcome,
As an experiment, this repository only has discussion pages, not issues. Actual bugs reports are welcome,
as are comments that you think something is a great idea, or better ways achive the same goal,
or nice examples showing how it works.

Expand All @@ -33,7 +33,12 @@ As will any features which migrate to Flux itself.

## Current Features

There are no formal documentation pages, but these links to the source will show you docstrings
(which are also available at the REPL prompt).

* Layers [`Split` and `Join`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/split_join.jl)
* More advanced [`train!` function](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/train.jl)
* Macro for [making custom layers](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/compact.jl) quickly
* *Two* macros for making custom layers quickly:
[`@compact(kw...) do ...`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/compact.jl), and
[`@autostruct function Mine(d) ...`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/autostruct.jl).
* Experimental [`apply(c::Chain, x)`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/chain.jl) interface
3 changes: 3 additions & 0 deletions src/Fluxperimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ include("compact.jl")
include("noshow.jl")
export NoShow

include("autostruct.jl")
export @autostruct

include("new_recur.jl")

end # module Fluxperimental
157 changes: 157 additions & 0 deletions src/autostruct.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@

"""
@autostruct function MyLayer(d); ...; MyLayer(f1, f2, ...); end
This is a macro for easily defining new layers.
Recall that Flux layer is a callable `struct` which may contain parameter arrays.
Usually, the steps to define a new one are:
1. Define a `struct MyLayer` with the desired fields,
and tell Flux to look inside with `@layer MyLayer` (or on earlier versions, `@functor`).
2. Define a constructor function like `MyLayer(d::Int)`,
which initialises the parameters (say to `randn32(d, d)`)
and returns an instance of the `struct`, some `m::MyLayer`.
3. Define the forward pass, by making the struct callable: `(m::MyLayer)(x) = ...`
Given the function in step 2, this macro handles step 1. You still do step 3.
If you change the name or the fields, then the `struct` definition is automatically replaced.
This works because this definition uses an auto-generated name, which is `== MyLayer`.
(But existing instances of the old `struct` are not changed in any way!)
Writing `@autostruct :expand function MyLayer(d)` will use `@layer :expand MyLayer`,
and result in container-style pretty-printing.
Note that the `struct` will sometimes have extra fields containing `nothing`,
to ensure that your constructor function cannot be ambiguous with the default constructor.
In the example below, `@autostruct :expand function MyModel(d, d2=d)` will show this behaviour.
## Example
```julia
@autostruct function MyModel(d::Int)
alpha, beta = [Dense(d=>d, tanh) for _ in 1:2] # arbitrary code here, not just keyword-like
beta.bias[:] .= 1/d
return MyModel(alpha, beta) # this must be very simple, no = signs allowed (return optional)
end
function (m::MyModel)(x) # forward pass looks just like a normal struct
y = m.alpha(x)
z = m.beta(y)
(x .+ y .+ z)./3
end
Flux.trainable(m::MyModel) = (; m.alpha) # if necessary, restrict which fields are trainable
Base.show(io::IO, m::MyModel) = # if desired, replace default printing "MyModel(...)"
print(io, "MyModel(", size(m.alpha.weight, 1), ")")
MyModel(2) isa MyModel # true
```
For comparison, the use of `@compact` to do much the same thing looks like this -- shorter,
but further from being ordinary Julia code.
```julia
function MyModel2(d::Int)
alpha, beta = [Dense(d=>d, tanh) for _ in 1:2]
beta.bias[:] .= 1/d
@compact(; alpha, beta) do x
y = alpha(x)
z = beta(y)
(x .+ y .+ z)./3
end
end
MyModel2(2) isa Fluxperimental.CompactLayer # no easy struct type
```
"""
macro autostruct(ex)
esc(_autostruct(ex))
end

macro autostruct(ex1, ex2)
(ex1 isa QuoteNode && ex1.value == :expand) || throw("Expected either `@autostruct function` or `@autostruct :expand function`")
esc(_autostruct(ex2; expand=true))
end

const DEFINE = Dict{UInt, Tuple}()
const NOFIELD = :_nothing # perhaps better not gensym(:nothing), to be the same after re-starting, as field names survive in Flux.state(model)

function _autostruct(expr; expand::Bool=false)
# Check first & last line of the input expression:
Meta.isexpr(expr, :function) || throw("Expected a function definition, like `@autostruct function MyStruct(...); ...`")
fun = expr.args[1].args[1]
ret = expr.args[2].args[end]
if Meta.isexpr(ret, :return)
ret = only(ret.args)
end
Meta.isexpr(ret, :call) || throw("Last line of `@autostruct function $fun` must return `$fun(field1, field2, ...)`")
ret.args[1] === fun || throw("Last line of `@autostruct function $fun` must return `$fun(field1, field2, ...)`")
for ex in ret.args
ex isa Symbol || throw("Last line of `@autostruct function $fun` must return `$fun(field1, field2, ...)` with only symbols, got $ex")
contains(string(ex), string(NOFIELD)) && throw("Field names containing `$NOFIELD` are reserved by `@autostruct`")
end

# Ensure that there are more fields than input arguments:
narg = _count_args(expr.args[1])
nret = _count_args(ret)
nothings = Int[]
for i in 1:(narg-nret+1)
sy = Symbol(NOFIELD, :_, i)
push!(ret.args, sy)
push!(nothings, length(ret.args)) # index for later use
end

# If the last line is new, construct struct definition:
name, defex = get!(DEFINE, hash(ret, UInt(expand))) do
name = gensym(fun)
fields = map(enumerate(ret.args[2:end])) do (i, field)
if occursin(string(NOFIELD), string(field))
:($field::Nothing)
else
type = Symbol("T#", i)
:($field::$type)
end
end
types = filter(T -> T != :Nothing, map(f -> f.args[2], fields))
layer = if !expand
:($Flux.@layer $name)
else
str = "$fun("
quote
$Flux.@layer :expand $name
Flux._show_pre_post(::$name) = $str, ")" # needs https://github.com/FluxML/Flux.jl/pull/2344
end
end
str = "$fun(...)"
ex = quote
struct $name{$(types...)}
$(fields...)
end
$layer
$Base.show(io::IO, _::$name) = $print(io, $str)
$fun = $name
end
(name, ex)
end

# Change first line to use the struct's name:
expr.args[1].args[1] = name
# Change last line to use nothing:
for j in nothings
ret.args[j] = nothing
end
quote
$(defex.args...) # struct definition
$expr # constructor function
end
end

function _count_args(ex::Expr)
@assert Meta.isexpr(ex, :call)
count(ex.args[2:end]) do arg
# Three options for f(a, b::Int, c=3), but not keywords
arg isa Symbol || Meta.isexpr(arg, [:(::), :kw])
end
end
47 changes: 47 additions & 0 deletions test/autostruct.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
Fluxperimental.DEFINE |> empty!

@autostruct function New1(a)
A = Dense(a => 2a)
New1(A)
end

(m::New1)(x) = one.(m.A(x))

@testset "simple case" begin
m1 = New1(3)
@test m1 isa New1
@test Flux.state(m1).A.bias == zeros(Float32, 6)
@test m1([1,2,3]) == ones(Float32, 6)
end

id1 = string(New1) # something like "var\"##New1#265\""

@autostruct function New1(a) # re-definition of constructor, same struct!
A = Dense(2a => a, tanh)
New1(A)
end

(m::New1)(x) = one.(m.A(x)) .+ 2 # re-definition of forward pass, same struct!

@testset "re-defined" begin
@test string(New1) == id1
m2 = New1(2)
@test m2 isa New1
@test Flux.state(m2).A.bias == zeros(Float32, 2)
@test m2([1,2,3,4]) == [3f0, 3f0]
end

@autostruct :expand function New1(a, b=3) # new struct, both for :expand and for b argument
A = Dense(a => b)
New1(A)
end

@testset "new defn" begin
@test string(New1) != id1
m3 = New1(3)
@test m3 isa New1
@test Flux.state(m3).A.bias == zeros(Float32, 3)
# pretty printing
@test contains(repr("text/plain", m3), "New1(\n")
@test contains(repr("text/plain", m3), "Dense(3 => 3)")
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ using Flux, Fluxperimental
include("compact.jl")
include("noshow.jl")

include("autostruct.jl")

include("new_recur.jl")

end

2 comments on commit 197d494

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/118328

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" 197d49454b2e4699b043838fe8c00b4a4a8286fe
git push origin v0.2.0

Please sign in to comment.