Workaround Flux#1027 #4

Merged · 21 commits · Jul 8, 2021
Changes from 5 commits
2 changes: 1 addition & 1 deletion .gitignore
@@ -1 +1 @@
/Manifest.toml
Manifest.toml
9 changes: 7 additions & 2 deletions Project.toml
@@ -1,22 +1,27 @@
name = "LegolasFlux"
uuid = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4"
authors = ["Beacon Biosignals, Inc."]
version = "0.1.0"
version = "0.1.1"

[deps]
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Arrow = "1"
Flux = "0.12"
Legolas = "0.1, 0.2"
Tables = "1"
julia = "1.5"

[extras]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Flux"]
test = ["Test", "Flux", "StableRNGs", "Statistics", "Random"]
12 changes: 8 additions & 4 deletions README.md
@@ -4,7 +4,8 @@
[![codecov](https://codecov.io/gh/beacon-biosignals/LegolasFlux.jl/branch/main/graph/badge.svg?token=NHYUL22HCC)](https://codecov.io/gh/beacon-biosignals/LegolasFlux.jl)

LegolasFlux provides some simple functionality to use [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/)'s
extensible Arrow schemas as means to serialize Flux models using Flux's `params` and `loadparams!`.
extensible Arrow schemas as a means to serialize Flux models, similarly to using Flux's `params` and `loadparams!`
(instead, we export the similar functions `weights` and `loadweights!`, which handle layers like `BatchNorm` correctly for this purpose).

The aim is to serialize only the numeric weights, *not* the code defining the model. This is a very different approach
from e.g. BSON.jl, and hopefully much more robust.
@@ -28,14 +29,14 @@ my_model = make_my_model()
using LegolasFlux

# We can save whatever other columns we'd like to as well as the `weights`.
model_row = ModelRow(; weights = collect(params(cpu(my_model))), architecture_version = 1, loss = 0.5)
model_row = ModelRow(; weights = collect(weights(cpu(my_model))), architecture_version = 1, loss = 0.5)
write_model_row("my_model.model.arrow", model_row)

# Great! Later on, we want to re-load our model weights.
fresh_model = make_my_model()

model_row = read_model_row("my_model.model.arrow")
Flux.loadparams!(fresh_model, collect(model_row.weights))
loadweights!(fresh_model, collect(model_row.weights))
# Now our params have been loaded back into `fresh_model`.
# Note we needed to `collect` the weights before we use them.

@@ -47,6 +48,8 @@ model_row.loss # 0.5
We can make use of the `architecture_version` column to specify a version number for the architectures, in order
to keep track of which architectures the weights are valid for.

See [examples/digits.jl](examples/digits.jl) for a larger example.

## `LegolasFlux.ModelRow`

A `LegolasFlux.ModelRow` is the central object of LegolasFlux. It acts as a Tables.jl-compatible row that can store the weights
@@ -78,4 +81,5 @@ one might name files produced by this row as e.g. `training_run.digits.model.arr
Note that in this example the schema is called `digits.model` instead of just, say, `digits`, since the package Digits might want to
create other Legolas schemas as well at some point.

Check out the [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/) repo to see more about how its extensible schema system works.
Check out the [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/) repo to see more about how its extensible schema system works,
and the example at [examples/digits.jl](examples/digits.jl).
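
Not part of this diff, but a rough illustration of the `architecture_version` bookkeeping described in the README above: a minimal sketch in which `make_my_model_v1` and `make_my_model_v2` are hypothetical constructors owned by the caller.

```julia
using LegolasFlux

# Hypothetical sketch: dispatch on the saved `architecture_version` column to
# pick the matching (caller-defined) constructor before loading the weights.
row = read_model_row("my_model.model.arrow")
if row.architecture_version == 1
    model = make_my_model_v1()   # hypothetical v1 constructor
elseif row.architecture_version == 2
    model = make_my_model_v2()   # hypothetical v2 constructor
else
    error("unknown architecture_version: $(row.architecture_version)")
end
loadweights!(model, collect(row.weights))
```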
7 changes: 7 additions & 0 deletions examples/Project.toml
@@ -0,0 +1,7 @@
[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd"
LegolasFlux = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
141 changes: 141 additions & 0 deletions examples/digits.jl
@@ -0,0 +1,141 @@
# Model modified from
# https://discourse.julialang.org/t/how-to-drop-the-dropout-layers-in-flux-jl-when-assessing-model-performance/19924

using Flux, Statistics, Random, Test
# Uncomment to use MNIST data
# using MLDatasets: MNIST
using StableRNGs
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Legolas, LegolasFlux

# This should store all the information needed
# to construct the model.
Base.@kwdef struct DigitsConfig
seed::Int = 5
dropout_rate::Float32 = 0f1
end

# Here's our model object itself, just a `DigitsConfig` and
# a `chain`. We keep the config around so it's easy to save out
# later.
struct DigitsModel
chain::Chain
config::DigitsConfig
end

# Ensure Flux can recurse into our model to find params etc
Flux.@functor DigitsModel (chain,)

# Construct the actual model from a config object. This is the only
# constructor that should be used, to ensure the model is created just
# from the config object alone.
function DigitsModel(config::DigitsConfig = DigitsConfig())
dropout_rate = config.dropout_rate
Random.seed!(config.seed)
chain = Chain(
Dropout(dropout_rate),
Conv((3, 3), 1=>32, relu),
BatchNorm(32, relu),
MaxPool((2,2)),
Dropout(dropout_rate),
Conv((3, 3), 32=>16, relu),
Dropout(dropout_rate),
MaxPool((2,2)),
Dropout(dropout_rate),
Conv((3, 3), 16=>10, relu),
Dropout(dropout_rate),
x -> reshape(x, :, size(x, 4)),
Dropout(dropout_rate),
Dense(90, 10), softmax)
return DigitsModel(chain, config)
end

# Our model acts on input just by applying the chain.
(m::DigitsModel)(x) = m.chain(x)

# Here, we define a schema extension of the `legolas-flux.model` schema.
# We add our `DigitsConfig` object, as well as the epoch and accuracy.
const DigitsRow = Legolas.@row("digits.model@1" > "legolas-flux.model@1",
config::DigitsConfig,
epoch::Union{Missing, Int},
accuracy::Union{Missing, Float32})

# Construct a `DigitsRow` from a model by collecting the `weights`.
# This can then be saved with e.g. `LegolasFlux.write_model_row`.
function DigitsRow(model::DigitsModel; epoch=missing, accuracy=missing)
w = collect(weights(model))
return DigitsRow(; weights=w, model.config, epoch, accuracy)
end

# Construct a `DigitsModel` from a row satisfying the `DigitsRow` schema,
# i.e. one with a `weights` and `config::DigitsConfig`.
# This could be the result of `LegolasFlux.read_model_row`.
function DigitsModel(row)
m = DigitsModel(row.config)
loadweights!(m, collect(row.weights))
return m
end


# Increase to get more training/test data
N_train = 1_000
N_test = 50

##
# to use MNIST data, uncomment these
# train_x, train_y = MNIST.traindata(Float32, 1:N_train)
# test_x, test_y = MNIST.testdata(Float32, 1:N_test)

# Random data:
rng = StableRNG(735)
train_x = rand(rng, Float32, 28, 28, N_train)
train_y = rand(rng, 0:9, N_train)
test_x = rand(rng, Float32, 28, 28, N_test)
test_y = rand(rng, 0:9, N_test)
##

# Partition into batches of size 32
batch_size = 32
train = [(reshape(train_x[:, :, I], 28, 28, 1, :), onehotbatch(train_y[I], 0:9))
for I in partition(1:N_train, batch_size)]

tX = reshape(test_x, 28, 28, 1, :)
tY = onehotbatch(test_y, 0:9)

function accuracy(m, x, y)
testmode!(m)
val = mean(onecold(m(x)) .== onecold(y))
trainmode!(m)
return val
end

function train_model!(m; N = N_train)
loss = (x, y) -> crossentropy(m(x), y)
opt = ADAM()
evalcb = throttle(() -> @show(accuracy(m, tX, tY)), 5)
Flux.@epochs 1 Flux.train!(loss, params(m), Iterators.take(train, N), opt, cb = evalcb)
return accuracy(m, tX, tY)
end

m = DigitsModel()

# increase N to actually train more than a tiny amount
acc = train_model!(m; N = 10)

# Let's serialize out the weights into a `DigitsRow`.
# We could save this here with `write_model_row`.
row = DigitsRow(m; epoch=1, accuracy=acc)

testmode!(m)
input = tX[:, :, :, 1:1]
output = m(input)
label = tY[:, 1]

# Let's now reconstruct the model from the `row` and check that we get
# the same outputs.
m2 = DigitsModel(row)
testmode!(m2)
output2 = m2(input)

@test output ≈ output2
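
Not part of the example file above, but a small follow-on check one could run with the same definitions in scope (`row`, `acc`, `m`, `m2`, `tX`, `tY`, and the `accuracy` helper): since `m2` was rebuilt purely from the row's config and weights, it should score identically to the original model.

```julia
# Sketch, assuming the definitions from examples/digits.jl above are loaded.
# The row carries the metadata we passed in...
@test row.epoch == 1
@test row.accuracy ≈ acc

# ...and a model rebuilt purely from the row's config and weights scores the same.
@test accuracy(m2, tX, tY) == accuracy(m, tX, tY)
```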
1 change: 1 addition & 0 deletions src/LegolasFlux.jl
@@ -110,5 +110,6 @@ function read_model_row(io_or_path)
return only(rows)
end

include("flux_workarounds.jl")

end # module
57 changes: 57 additions & 0 deletions src/flux_workarounds.jl
@@ -0,0 +1,57 @@
using Flux: BatchNorm, InstanceNorm, GroupNorm, Params, trainable
using Base: IdSet
export weights, loadweights!

"""
LegolasFlux.other_weights(layer) -> Vararg{Array}

Given a layer with params that are not captured by `Flux.trainable`, produce
a tuple of arrays corresponding to these parameters (analogous to `Flux.trainable`).
"""
function other_weights end

other_weights(layer) = ()
other_weights(layer::BatchNorm) = (layer.μ, layer.σ²)
other_weights(layer::InstanceNorm) = (layer.μ, layer.σ²)
other_weights(layer::GroupNorm) = (layer.μ, layer.σ²)

#####
##### `weights`
#####

# The following is a copy of <https://github.com/FluxML/Flux.jl/blob/335286adf118b61ad6fffa5937bd9358477a00c9/src/functor.jl#L41-L63>
# with `params` changed to `weights` and the addition of the lines
# ```julia
# for child in other_weights(x)
# weights!(p, child, seen)
# end
# ```
# to `weights!(p::Params, x, seen = IdSet())`.

weights!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)

function weights!(p::Params, x, seen = IdSet())
x in seen && return
push!(seen, x)
for child in trainable(x)
weights!(p, child, seen)
end

for child in other_weights(x)
weights!(p, child, seen)
end
end

function weights(m...)
ps = Params()
weights!(ps, m)
return ps
end

function loadweights!(m, xs)
for (p, x) in zip(weights(m), xs)
size(p) == size(x) ||
error("Expected param size $(size(p)), got $(size(x))")
copyto!(p, x)
end
end
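
Not part of this diff, but a rough illustration of what the workaround buys (assuming Flux 0.12, per the `[compat]` entry above): `Flux.params` on a `BatchNorm` captures only the trainable `β` and `γ`, so a `params`-based round trip drops the running statistics (the issue referenced in this PR's title), whereas `weights`/`loadweights!` also carry `μ` and `σ²`.

```julia
# Sketch, assuming Flux 0.12 as pinned in Project.toml.
using Flux
using LegolasFlux: weights, loadweights!

layer = BatchNorm(2)
trainmode!(layer)
layer(rand(Float32, 2, 8))            # a forward pass in train mode updates μ and σ²

length(collect(Flux.params(layer)))   # 2 arrays: β, γ (trainable only)
length(collect(weights(layer)))       # 4 arrays: β, γ, plus the running stats μ, σ²

fresh = BatchNorm(2)
loadweights!(fresh, collect(weights(layer)))
fresh.μ ≈ layer.μ && fresh.σ² ≈ layer.σ²   # true: the running statistics came along
```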
42 changes: 39 additions & 3 deletions test/runtests.jl
@@ -13,18 +13,20 @@ function test_weights()
return [reshape(Float32.(1:prod(s)), s) for s in shapes]
end

@testset begin
# This simple model should work with both Flux's `params/loadparams!` and
# our `weights/loadweights!`. The only difference is in layers with `!isempty(other_weights(layer))`.
@testset "using ($get_weights, $load_weights)" for (get_weights, load_weights) in [(weights, loadweights!), (params, Flux.loadparams!)]
my_model = make_my_model()
Flux.loadparams!(my_model, test_weights())

model_row = ModelRow(; weights=collect(params(my_model)))
model_row = ModelRow(; weights=collect(get_weights(my_model)))
write_model_row("my_model.model.arrow", model_row)

fresh_model = make_my_model()

model_row = read_model_row("my_model.model.arrow")
weights = collect(model_row.weights)
Flux.loadparams!(fresh_model, weights)
load_weights(fresh_model, weights)

@test collect(params(fresh_model)) == weights == test_weights()

@@ -43,3 +45,37 @@ end
tbl = [(; weights = w)]
@test Arrow.Table(Arrow.tobuffer(tbl)).weights[1] == w
end

@testset "`flux_workarounds`" begin
@testset "layer $layer" for layer in [BatchNorm, InstanceNorm, (c) -> GroupNorm(c, 1), c -> identity]
mk_model = () -> (Random.seed!(1); Chain(Dense(1, 10), Dense(10, 10), layer(1), Dense(10, 1)))
model = mk_model()
trainmode!(model)
x = reshape([1f0], 1, 1, 1)
for i = 1:10
x = model(x)
end
testmode!(model)
w = collect(weights(model))
p = collect(params(model))
output = model(x)

r1 = mk_model()
loadweights!(r1, w)
testmode!(r1)

@test output ≈ r1(x)

if layer == BatchNorm
r2 = mk_model()
Flux.loadparams!(r2, p)
testmode!(r2)

@test_broken output ≈ r2(x)
end
end
end

@testset "Example" begin
include("../examples/digits.jl")
end