Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ci: setup benchmarking CI #821

Merged
merged 15 commits on Aug 19, 2024
Prev Previous commit
Next Next commit
ci: fix benchmarks
[skip tests] [skip docs]
  • Loading branch information
avik-pal committed Aug 19, 2024
commit a60981c6e5197c595ac66076f3b8ea9466327abb
9 changes: 9 additions & 0 deletions benchmarks/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
style = "sciml"
whitespace_in_kwargs = false
margin = 92
indent = 4
format_docstrings = true
separate_kwargs_with_semicolon = true
always_for_in = true
annotate_untyped_fields_with_any = false
join_lines_based_on_source = true
4 changes: 2 additions & 2 deletions benchmarks/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ function benchmark_forward_pass!(suite::BenchmarkGroup, group::String, tag, mode
synchronize($dev)
end setup=begin
fdev = group_to_flux_backend($group)
x = randn(StableRNG(0), Float32, $x_dims) |> fdev
x = randn(Random.default_rng(), Float32, $x_dims) |> fdev
fmodel = $(flux_model()) |> fdev
Flux.testmode!(fmodel, true)
fmodel(x) # Warm up
Expand Down Expand Up @@ -149,7 +149,7 @@ function benchmark_reverse_pass_flux!(
synchronize($dev)
end setup=begin
fdev = group_to_flux_backend($group)
x = randn(StableRNG(0), Float32, $x_dims) |> fdev
x = randn(Random.default_rng(), Float32, $x_dims) |> fdev
fmodel = $(model)() |> fdev
Zygote.gradient(sumabs2, fmodel, x) # Warm up
end
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/setups/layers.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function add_dense_benchmarks!(suite::BenchmarkGroup, group::String)
for n in (16, 128, 512,), act in (identity, relu, gelu)
for n in (16, 128, 512), act in (identity, relu, gelu)
layer = Dense(n => n, act)
simple_chains = n ≤ 200 ? Lux.ToSimpleChainsAdaptor((static(n),)) : nothing
flux_model = () -> Flux.Dense(n => n, act)
Expand Down
3 changes: 0 additions & 3 deletions benchmarks/setups/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ function add_vgg16_benchmarks!(suite::BenchmarkGroup, group::String)
name="ConvBN")
end

#! format: off
vgg16 = Chain(
Chain(conv_bn((3, 3), 3 => 64, relu; pad=(1, 1), stride=(1, 1)),
conv_bn((3, 3), 64 => 64, relu; pad=(1, 1), stride=(1, 1)),
Expand All @@ -30,7 +29,6 @@ function add_vgg16_benchmarks!(suite::BenchmarkGroup, group::String)
Chain(Dense(512, 4096, relu), Dropout(0.5f0), Dense(4096, 4096, relu),
Dropout(0.5f0), Dense(4096, 10); name="Classifier"); disable_optimizations=true)


flux_model = () -> Flux.Chain(
Flux.Conv((3, 3), 3 => 64, relu; pad=(1, 1), stride=(1, 1)),
Flux.BatchNorm(64), Flux.Conv((3, 3), 64 => 64, relu; pad=(1, 1), stride=(1, 1)),
Expand All @@ -51,7 +49,6 @@ function add_vgg16_benchmarks!(suite::BenchmarkGroup, group::String)
Flux.Conv((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(512),
Flux.MaxPool((2, 2)), Flux.flatten, Flux.Dense(512, 4096, relu), Flux.Dropout(0.5),
Flux.Dense(4096, 4096, relu), Flux.Dropout(0.5), Flux.Dense(4096, 10))
#! format: on

for bsize in (2, 16, 64)
benchmark_forward_pass!(
Expand Down
8 changes: 3 additions & 5 deletions src/helpers/match_eltype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,9 @@ For `"convert"` only the following conversions are done:
"""
function match_eltype end

@static if LuxPreferences.ELTYPE_MISMATCH_HANDLING == "none" # Just return the input
match_eltype(layer, ps, st, x) = x
function match_eltype(layer, ps, st, x, args...)
return (x, args...)
end
@static if ELTYPE_MISMATCH_HANDLING == "none" # Just return the input
@inline match_eltype(layer, ps, st, x) = x
@inline match_eltype(layer, ps, st, x, args...) = (x, args...)
else
function match_eltype(layer, ps, st, x)
fn = let elType = recursive_eltype((ps, st), Val(true)), layer = layer
Expand Down