
Workaround Flux#1027 #4

Merged · 21 commits · Jul 8, 2021
Changes from 13 commits
3 changes: 2 additions & 1 deletion Project.toml
@@ -5,13 +5,14 @@ version = "0.1.1"

[deps]
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Arrow = "1"
Flux = "0.12"
Functors = "0.2.1"
Legolas = "0.1, 0.2"
Tables = "1"
julia = "1.5"
10 changes: 5 additions & 5 deletions README.md
@@ -5,7 +5,7 @@

LegolasFlux provides some simple functionality to use [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl/)'s
extensible Arrow schemas as means to serialize Flux models similarly to using Flux's `params` and `loadparams!`
-(instead, we export similar functions `weights` and `loadweights!` which handle layers like `BatchNorm` correctly for this purpose).
+(instead, we export similar functions `weights` and `load_weights!` which handle layers like `BatchNorm` correctly for this purpose).

The aim is to serialize only the numeric weights, *not* the code defining the model. This is a very different approach
from e.g. BSON.jl, and hopefully much more robust.
@@ -29,16 +29,16 @@ my_model = make_my_model()
using LegolasFlux

# We can save whatever other columns we'd like to as well as the `weights`.
-model_row = ModelRow(; weights = collect(weights(cpu(my_model))), architecture_version = 1, loss = 0.5)
+model_row = ModelRow(; weights = weights(cpu(my_model)),
+                     architecture_version=1, loss=0.5)
write_model_row("my_model.model.arrow", model_row)

# Great! Later on, we want to re-load our model weights.
fresh_model = make_my_model()

model_row = read_model_row("my_model.model.arrow")
-loadweights!(fresh_model, collect(model_row.weights))
-# Now our params have been loaded back into `fresh_model`.
-# Note we needed to `collect` the weights before we use them.
+load_weights!(fresh_model, model_row.weights)
+# Now our weights have been loaded back into `fresh_model`.

# We can also check out our other columns:
model_row.loss # 0.5
41 changes: 20 additions & 21 deletions examples/digits.jl
@@ -30,24 +30,24 @@ Flux.@functor DigitsModel (chain,)
# Construct the actual model from a config object. This is the only
# constructor that should be used, to ensure the model is created just
# from the config object alone.
-function DigitsModel(config::DigitsConfig = DigitsConfig())
+function DigitsModel(config::DigitsConfig=DigitsConfig())
dropout_rate = config.dropout_rate
Random.seed!(config.seed)
-    chain = Chain(
-        Dropout(dropout_rate),
-        Conv((3, 3), 1=>32, relu),
-        BatchNorm(32, relu),
-        MaxPool((2,2)),
-        Dropout(dropout_rate),
-        Conv((3, 3), 32=>16, relu),
-        Dropout(dropout_rate),
-        MaxPool((2,2)),
-        Dropout(dropout_rate),
-        Conv((3, 3), 16=>10, relu),
-        Dropout(dropout_rate),
-        x -> reshape(x, :, size(x, 4)),
-        Dropout(dropout_rate),
-        Dense(90, 10), softmax)
+    chain = Chain(Dropout(dropout_rate),
+                  Conv((3, 3), 1 => 32, relu),
+                  BatchNorm(32, relu),
+                  MaxPool((2, 2)),
+                  Dropout(dropout_rate),
+                  Conv((3, 3), 32 => 16, relu),
+                  Dropout(dropout_rate),
+                  MaxPool((2, 2)),
+                  Dropout(dropout_rate),
+                  Conv((3, 3), 16 => 10, relu),
+                  Dropout(dropout_rate),
+                  x -> reshape(x, :, size(x, 4)),
+                  Dropout(dropout_rate),
+                  Dense(90, 10),
+                  softmax)
return DigitsModel(chain, config)
end

@@ -64,16 +64,15 @@ const DigitsRow = Legolas.@row("digits.model@1" > "legolas-flux.model@1",
# Construct a `DigitsRow` from a model by collecting the `weights`.
# This can then be saved with e.g. `LegolasFlux.write_model_row`.
function DigitsRow(model::DigitsModel; epoch=missing, accuracy=missing)
-    w = collect(weights(model))
-    return DigitsRow(; weights=w, model.config, epoch, accuracy)
+    return DigitsRow(; weights=weights(model), model.config, epoch, accuracy)
end

# Construct a `DigitsModel` from a row satisfying the `DigitsRow` schema,
# i.e. one with a `weights` and `config::DigitsConfig`.
# This could be the result of `LegolasFlux.read_model_row`.
function DigitsModel(row)
m = DigitsModel(row.config)
-    loadweights!(m, collect(row.weights))
+    load_weights!(m, row.weights)
return m
end

@@ -114,14 +114,14 @@ function train_model!(m; N = N_train)
loss = (x, y) -> crossentropy(m(x), y)
opt = ADAM()
evalcb = throttle(() -> @show(accuracy(m, tX, tY)), 5)
-    Flux.@epochs 1 Flux.train!(loss, params(m), Iterators.take(train, N), opt, cb = evalcb)
+    Flux.@epochs 1 Flux.train!(loss, params(m), Iterators.take(train, N), opt; cb=evalcb)
return accuracy(m, tX, tY)
end

m = DigitsModel()

# increase N to actually train more than a tiny amount
-acc = train_model!(m; N = 10)
+acc = train_model!(m; N=10)

# Let's serialize out the weights into a `DigitsRow`.
# We could save this here with `write_model_row`.
5 changes: 4 additions & 1 deletion src/LegolasFlux.jl
@@ -1,11 +1,14 @@
module LegolasFlux

export write_model_row, read_model_row
export weights, load_weights!

using Legolas
using Arrow
using Arrow.ArrowTypes
using Tables
using Functors
using Base: IdSet

const LEGOLAS_SCHEMA = Legolas.Schema("legolas-flux.model@1")

@@ -110,6 +113,6 @@ function read_model_row(io_or_path)
return only(rows)
end

-include("flux_workarounds.jl")
+include("functors.jl")

end # module
57 changes: 0 additions & 57 deletions src/flux_workarounds.jl

This file was deleted.

43 changes: 43 additions & 0 deletions src/functors.jl
@@ -0,0 +1,43 @@
# Modified version of `fcollect` to use an `IdSet` cache so that
# distinct arrays whose values happen to be duplicates are each kept.
# <https://github.com/FluxML/Functors.jl/issues/16>
function fcollect2(x; output=[], cache=IdSet(), exclude=_ -> false)
x in cache && return output
if !exclude(x)
push!(cache, x)
push!(output, x)
foreach(y -> fcollect2(y; cache=cache, output=output, exclude=exclude), Functors.children(x))
end
return output
end
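To see why the cache must be identity-based: two parameter arrays that happen to hold equal values are `==` and hash alike, so a hash-based `Set` would deduplicate them, while an `IdSet` compares with `===` and keeps both. A minimal pure-Julia sketch (variable names are illustrative):

```julia
# Two distinct arrays with identical contents, e.g. two bias vectors
# that were both initialized to zeros.
a = zeros(3)
b = zeros(3)

hash_cache = Set{Any}()       # membership via isequal/hash
push!(hash_cache, a)
push!(hash_cache, b)          # dropped: b == a

id_cache = Base.IdSet{Any}()  # membership via === (objectid)
push!(id_cache, a)
push!(id_cache, b)            # kept: b !== a

length(hash_cache)  # 1
length(id_cache)    # 2
```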

"""
weights(m) -> Vector{Array}

Returns the weights of a model by using `Functors.children` to recurse
through the model, keeping any arrays found. The `@functor` macro defines
`Functors.children` automatically so that should be sufficient to support
custom types.
"""
weights(m) = filter(x -> x isa Array, fcollect2(m))
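The traversal idea can be sketched without Flux or Functors at all; `childvalues`, `collect_arrays`, `MyDense`, and `MyChain` below are hypothetical stand-ins for the `Functors.children`-based recursion, not part of this package:

```julia
# Treat arrays, numbers, and functions as leaves; recurse into the
# fields of everything else (a crude stand-in for Functors.children).
childvalues(x) = x isa Union{Array,Number,Function} ? () :
                 ntuple(i -> getfield(x, i), fieldcount(typeof(x)))

function collect_arrays(x; out=Any[], seen=Base.IdSet{Any}())
    x in seen && return out
    push!(seen, x)
    x isa Array && push!(out, x)
    foreach(c -> collect_arrays(c; out=out, seen=seen), childvalues(x))
    return out
end

struct MyDense; W::Matrix{Float64}; b::Vector{Float64}; end
struct MyChain; layers::Tuple; end

m = MyChain((MyDense(rand(3, 2), rand(3)), MyDense(rand(1, 3), rand(1))))
length(collect_arrays(m))  # 4 — two weight matrices and two bias vectors
```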

"""
load_weights!(m, xs)

Load weights `xs` into the model `m`, using [`weights`](@ref).
"""
function load_weights!(m, xs)
model_weights = weights(m)
if length(model_weights) != length(xs)
throw(ArgumentError("Number of weights given ($(length(xs))) does not match number of weights model expects ($(length(model_weights)))"))
end
for (i, (p, x)) in enumerate(zip(model_weights, xs))
if size(p) != size(x)
throw(ArgumentError("For the $(i)th weight expected param size $(size(p)), got $(size(x))"))
end
copyto!(p, x)
end
return nothing
end

load_weights!(m, xs::Weights) = load_weights!(m, collect(xs))
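One design note worth making explicit: `load_weights!` fills the model's existing arrays via `copyto!` rather than rebinding fields, so every reference the model already holds to a parameter array observes the loaded values. A small standalone sketch (names are illustrative):

```julia
W = zeros(2, 2)   # stands in for a layer's weight array
alias = W         # e.g. the same array referenced elsewhere in the model

copyto!(W, [1.0 2.0; 3.0 4.0])  # in-place write, as load_weights! does

alias == [1.0 2.0; 3.0 4.0]  # true — the alias sees the loaded values
```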
29 changes: 23 additions & 6 deletions test/runtests.jl
@@ -3,6 +3,7 @@ using Test
using Flux, LegolasFlux
using LegolasFlux: Weights, FlatArray, ModelRow
using Arrow
using Random

function make_my_model()
return Chain(Dense(1, 10), Dense(10, 10), Dense(10, 1))
@@ -14,10 +15,10 @@ function test_weights()
end

# This simple model should work with both Flux's `params/loadparams!` and
-# our `weights/loadweights!`. The only difference is in layers with `!isempty(other_weights(layer))`.
-@testset "using ($get_weights, $load_weights)" for (get_weights, load_weights) in [(weights, loadweights!), (params, Flux.loadparams!)]
+# our `weights/load_weights!`. The only difference is in layers with `!isempty(other_weights(layer))`.
+@testset "using ($get_weights, $load_weights)" for (get_weights, load_weights) in [(weights, load_weights!), (params, Flux.loadparams!)]
my_model = make_my_model()
-    Flux.loadparams!(my_model, test_weights())
+    load_weights(my_model, test_weights())

model_row = ModelRow(; weights=collect(get_weights(my_model)))
write_model_row("my_model.model.arrow", model_row)
@@ -35,6 +36,17 @@ end
rm("my_model.model.arrow")
end

@testset "Errors" begin
my_model = make_my_model()
w = test_weights()
w[end] = []
@test_throws ArgumentError load_weights!(my_model, w)

w = test_weights()
push!(w, [])
@test_throws ArgumentError load_weights!(my_model, w)
end

@testset "`Weights`" begin
v = [rand(Int8, 5), rand(Float32, 5, 5)]
@test Weights(v) isa Weights{Float32}
@@ -52,16 +64,16 @@ end
model = mk_model()
trainmode!(model)
x = reshape([1f0], 1, 1, 1)
-    for i = 1:10
+    for i in 1:10
x = model(x)
end
testmode!(model)
-    w = collect(weights(model))
+    w = weights(model)
p = collect(params(model))
output = model(x)

r1 = mk_model()
-    loadweights!(r1, w)
+    load_weights!(r1, w)
testmode!(r1)

@test output ≈ r1(x)
Expand All @@ -71,6 +83,11 @@ end
Flux.loadparams!(r2, p)
testmode!(r2)

# If this test *fails*, meaning `output ≈ r2(x)`,
# then perhaps we should revisit `load_weights!`
# and could consider switching to `Flux.loadparams!`.
# See https://github.com/beacon-biosignals/LegolasFlux.jl/pull/4
# for more.
@test_broken output ≈ r2(x)
end
end