Skip to content

Commit

Permalink
Change @autostruct to allow type restrictions (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Nov 1, 2024
1 parent 197d494 commit a850e1d
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 41 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Fluxperimental"
uuid = "3102ee7a-c841-4564-8f7f-ec69bd4fd658"
version = "0.2.0"
version = "0.2.1"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
94 changes: 57 additions & 37 deletions src/autostruct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,14 @@ Usually, the steps to define a new one are:
Given the function in step 2, this macro handles step 1. You still do step 3.
If you change the name or the fields, then the `struct` definition is automatically replaced.
If you change the name or types of the fields, then the `struct` definition is automatically replaced.
This works because this definition uses an auto-generated name, which is `== MyLayer`.
(But existing instances of the old `struct` are not changed in any way!)
Writing `@autostruct :expand function MyLayer(d)` will use `@layer :expand MyLayer`,
and result in container-style pretty-printing.
Note that the `struct` will sometimes have extra fields containing `nothing`,
to ensure that your constructor function cannot be ambiguous with the default constructor.
In the example below, `@autostruct :expand function MyModel(d, d2=d)` will show this behaviour.
## Example
## Examples
```julia
@autostruct function MyModel(d::Int)
Expand All @@ -49,6 +45,43 @@ Base.show(io::IO, m::MyModel) = # if desired, replace default printing "MyModel
MyModel(2) isa MyModel # true
```
The `struct` defined by the macro here is something like this:
```julia
struct MyModel001{T1, T2}
alpha::T1
beta::T2
end
```
Since this can hold any objects, even `MyModel("hello", "world")`.
As you can see by looking `methods(MyModel)`, there should never be an ambiguity
between the `struct`'s own constructor, and your `MyModel(d::Int)`.
You can also restrict the types allowed in the struct:
```
@autostruct :expand function MyOtherModel(d1, d2, act=identity)
gamma = Embedding(128 => d1)
delta = Dense(d1 => d2, act)
MyOtherModel(gamma::Embedding, delta::Dense) # struct will only hold these types
end
(m::MyOtherModel)(x) = softmax(m.delta(m.gamma(x))) # forward pass
methods(MyOtherModel) # will show 3 methods
```
This creates a struct like this:
```julia
struct MyOtherModel001{T1 <: Embedding, T2 <: Dense}
gamma::T1
delta::T2
end
```
## Compared to `@compact`
For comparison, the use of `@compact` to do much the same thing looks like this -- shorter,
but further from being ordinary Julia code.
Expand All @@ -64,6 +97,11 @@ function MyModel2(d::Int)
end
MyModel2(2) isa Fluxperimental.CompactLayer # no easy struct type
MyOtherModel2(d1, d2, act=identity) =
@compact(gamma = Embedding(128 => d1), delta=Dense(d1 => d2, act)) do x
softmax(delta(gamma(x)))
end
```
"""
macro autostruct(ex)
Expand All @@ -76,7 +114,6 @@ macro autostruct(ex1, ex2)
end

const DEFINE = Dict{UInt, Tuple}()
const NOFIELD = :_nothing # perhaps better not gensym(:nothing), to be the same after re-starting, as field names survive in Flux.state(model)

function _autostruct(expr; expand::Bool=false)
# Check first & last line of the input expression:
Expand All @@ -89,32 +126,27 @@ function _autostruct(expr; expand::Bool=false)
Meta.isexpr(ret, :call) || throw("Last line of `@autostruct function $fun` must return `$fun(field1, field2, ...)`")
ret.args[1] === fun || throw("Last line of `@autostruct function $fun` must return `$fun(field1, field2, ...)`")
for ex in ret.args
ex isa Symbol || throw("Last line of `@autostruct function $fun` must return `$fun(field1, field2, ...)` with only symbols, got $ex")
contains(string(ex), string(NOFIELD)) && throw("Field names containing `$NOFIELD` are reserved by `@autostruct`")
end

# Ensure that there are more fields than input arguments:
narg = _count_args(expr.args[1])
nret = _count_args(ret)
nothings = Int[]
for i in 1:(narg-nret+1)
sy = Symbol(NOFIELD, :_, i)
push!(ret.args, sy)
push!(nothings, length(ret.args)) # index for later use
ex isa Symbol && continue
Meta.isexpr(ex, :(::)) && continue
throw("Last line of `@autostruct function $fun` must return `$fun(field1, field2, ...)` or `$fun(field1::T1, field2::T2, ...)`, but got $ex")
end

# If the last line is new, construct struct definition:
name, defex = get!(DEFINE, hash(ret, UInt(expand))) do
name = gensym(fun)
fields = map(enumerate(ret.args[2:end])) do (i, field)
if occursin(string(NOFIELD), string(field))
:($field::Nothing)
fields = map(enumerate(ret.args[2:end])) do (i, ex)
field = ex isa Symbol ? ex : ex.args[1] # we allow `return MyModel(alpha, beta::Chain)`
type = Symbol("T#", i)
@show ex field type
:($field::$type)
end
types = map(fields, ret.args[2:end]) do ft, ex
if ex isa Symbol # then no type spec on return line
ft.args[2]
else
type = Symbol("T#", i)
:($field::$type)
Expr(:(<:), ft.args[2], ex.args[2])
end
end
types = filter(T -> T != :Nothing, map(f -> f.args[2], fields))
layer = if !expand
:($Flux.@layer $name)
else
Expand All @@ -138,20 +170,8 @@ function _autostruct(expr; expand::Bool=false)

# Change first line to use the struct's name:
expr.args[1].args[1] = name
# Change last line to use nothing:
for j in nothings
ret.args[j] = nothing
end
quote
$(defex.args...) # struct definition
$expr # constructor function
end
end

function _count_args(ex::Expr)
@assert Meta.isexpr(ex, :call)
count(ex.args[2:end]) do arg
# Three options for f(a, b::Int, c=3), but not keywords
arg isa Symbol || Meta.isexpr(arg, [:(::), :kw])
end
end
6 changes: 3 additions & 3 deletions test/autostruct.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Fluxperimental.DEFINE |> empty!

@autostruct function New1(a)
@autostruct function New1(a::Int)
A = Dense(a => 2a)
New1(A)
end
Expand All @@ -16,7 +16,7 @@ end

id1 = string(New1) # something like "var\"##New1#265\""

@autostruct function New1(a) # re-definition of constructor, same struct!
@autostruct function New1(a::Int) # re-definition of constructor, same struct!
A = Dense(2a => a, tanh)
New1(A)
end
Expand All @@ -33,7 +33,7 @@ end

@autostruct :expand function New1(a, b=3) # new struct, both for :expand and for b argument
A = Dense(a => b)
New1(A)
New1(A::Dense)
end

@testset "new defn" begin
Expand Down

2 comments on commit a850e1d

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/118483

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.1 -m "<description of version>" a850e1da0b11867d998757a4ba4f6653d9ad6d77
git push origin v0.2.1

Please sign in to comment.