Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: BFloat16 testing #115

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ LuxLibcuDNNExt = ["CUDA", "cuDNN"]
AMDGPU = "0.9.6"
Aqua = "0.8.7"
ArrayInterface = "7.9"
BFloat16s = "0.5.0"
CUDA = "5.3.2"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.16"
Expand Down Expand Up @@ -86,6 +87,7 @@ julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Expand All @@ -104,4 +106,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"]
test = ["Aqua", "BFloat16s", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"]
7 changes: 3 additions & 4 deletions test/common_ops/activation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
@testset "$mode" for (mode, aType, ongpu) in MODES
@testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus,
logsigmoid, gelu, swish, lisht, tanh, tanh_fast],
T in [Float16, Float32, Float64]
T in [BFloat16, Float32, Float64]

x = rand(rng, T, 4, 3) |> aType

y1 = apply_act(f, x)
y2 = apply_act_fast(f, x)
y3 = apply_act_fast2(f, x)

fp16 = T == Float16
atol = fp16 ? 1.0f-1 : 1.0f-3
rtol = fp16 ? 1.0f-1 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

@test y1≈y2 atol=atol rtol=rtol
@test y1≈y3 atol=atol rtol=rtol
Expand Down
14 changes: 6 additions & 8 deletions test/common_ops/conv_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module ConvSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s

_expand(N, i::Tuple) = i
_expand(N, i::Integer) = ntuple(_ -> i, N)
Expand Down Expand Up @@ -28,9 +28,8 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,

y_generic = LuxLib._generic_conv_bias_activation(activation, weight, x, bias, cdims)

fp16 = Tx == Float16 || Tw == Float16
atol = fp16 ? 1.0f-1 : 1.0f-3
rtol = fp16 ? 1.0f-1 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3
# Operation reordering has an effect on the accuracy of the results
@test y≈y_generic atol=atol rtol=rtol
@test eltype(y) == promote_type(Tw, Tx)
Expand Down Expand Up @@ -61,14 +60,13 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,
mp && push!(skip_backends, AutoReverseDiff())
((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) &&
push!(skip_backends, AutoTracker())
test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends,
soft_fail=(fp16 ? [AutoFiniteDiff()] : []))
test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends)
end

anonact = x -> gelu(x)

const ELTYPES = [(Float16, Float16), (Float32, Float16), (Float32, Float32),
(Float32, Float64), (Float64, Float64)]
const ELTYPES = [(BFloat16, BFloat16), (Float32, BFloat16),
(Float32, Float32), (Float32, Float64), (Float64, Float64)]
const ACTIVATIONS = [
identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, swish, anonact]

Expand Down
15 changes: 6 additions & 9 deletions test/common_ops/dense_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module DenseSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s

anonact = x -> x^3

Expand All @@ -25,24 +25,21 @@ function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode
@test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true
end

fp16 = Tx == Float16 || Tw == Float16
atol = fp16 ? 1.0f-1 : 1.0f-3
rtol = fp16 ? 1.0f-1 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

skip_backends = []
Tw != Tx && push!(skip_backends, AutoReverseDiff())
fp16 && push!(skip_backends, AutoFiniteDiff())

__f_grad = let activation = activation
(w, x, b) -> __f(activation, w, x, b)
end
test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends,
soft_fail=(fp16 ? [AutoFiniteDiff()] : []))
test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends)
end

const ALL_TEST_CONFIGS = Iterators.product(
((Float16, Float16), (Float32, Float16), (Float32, Float32),
(Float32, Float64), (Float64, Float64)),
((BFloat16, BFloat16), (Float32, BFloat16),
(Float32, Float32), (Float32, Float64), (Float64, Float64)),
(4, 8),
(4, 8),
(true, false),
Expand Down
26 changes: 8 additions & 18 deletions test/common_ops/dropout_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, ongpu) in MODES
@testset "$T: $x_shape" for T in (Float16, Float32, Float64),
@testset "$T: $x_shape" for T in (BFloat16, Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

x = randn(rng, T, x_shape) |> aType
Expand All @@ -26,9 +26,7 @@
__f = let rng = rng, T = T
x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon())))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), Colon())

Expand All @@ -48,7 +46,7 @@ end
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, ongpu) in MODES
@testset "$T: $x_shape" for T in (Float16, Float32, Float64),
@testset "$T: $x_shape" for T in (BFloat16, Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

x = randn(rng, T, x_shape) |> aType
Expand Down Expand Up @@ -76,9 +74,7 @@ end
x -> sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
Expand Down Expand Up @@ -106,9 +102,7 @@ end
x -> sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
Expand Down Expand Up @@ -137,9 +131,7 @@ end
x -> sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
Expand All @@ -165,7 +157,7 @@ end
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, ongpu) in MODES
@testset "$T: $x_shape" for T in (Float16, Float32, Float64),
@testset "$T: $x_shape" for T in (BFloat16, Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

x = randn(rng, T, x_shape) |> aType
Expand All @@ -186,9 +178,7 @@ end
__f = let rng = rng
x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
@test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any
Expand Down
24 changes: 5 additions & 19 deletions test/normalization/batchnorm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module BatchNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s

function _setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool)
x = gen_f(T, sz) |> aType
Expand Down Expand Up @@ -41,9 +41,8 @@ function run_batchnorm_testing(
y_simple, nt_simple = __batchnorm_basic(
x, scale, bias, rm, rv, training, act, T(0.9), epsilon)

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

@test y≈y_simple atol=atol rtol=rtol
if track_stats
Expand Down Expand Up @@ -82,22 +81,9 @@ function run_batchnorm_testing(
skip_backends = []
act === relu && push!(skip_backends, AutoFiniteDiff())

soft_fail = if fp16
if Sys.iswindows()
[AutoTracker(), AutoFiniteDiff(), AutoReverseDiff(), AutoForwardDiff()]
else
true
end
else
false
end

broken_backends = Sys.iswindows() && fp16 ? [AutoEnzyme()] : []

__f = (args...) -> sum(first(batchnorm(
args..., rm, rv, training, act, T(0.9), epsilon)))
test_gradients(
__f, x, scale, bias; atol, rtol, skip_backends, soft_fail, broken_backends)
test_gradients(__f, x, scale, bias; atol, rtol, skip_backends)
end

if anonact !== act
Expand All @@ -109,7 +95,7 @@ function run_batchnorm_testing(
end

const ALL_TEST_CONFIGS = Iterators.product(
[Float16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)),
[BFloat16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)),
(Val(true), Val(false)), (true, false), (true, false),
(identity, relu, tanh_fast, sigmoid_fast, anonact))

Expand Down
23 changes: 10 additions & 13 deletions test/normalization/groupnorm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module GroupNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s

function _setup_groupnorm(gen_f, aType, T, sz)
x = gen_f(T, sz) |> aType
Expand Down Expand Up @@ -34,20 +34,17 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, ongpu)

y_simple = _f2(x, scale, bias)

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

@test y≈y_simple atol=atol rtol=rtol

# Check the rrules
if !fp16
∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias)
@test ∂x≈∂x_simple atol=atol rtol=rtol
@test ∂scale≈∂scale_simple atol=atol rtol=rtol
@test ∂bias≈∂bias_simple atol=atol rtol=rtol
end
∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias)
@test ∂x≈∂x_simple atol=atol rtol=rtol
@test ∂scale≈∂scale_simple atol=atol rtol=rtol
@test ∂bias≈∂bias_simple atol=atol rtol=rtol

@test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any
@jet groupnorm(x, scale, bias, groups, act, epsilon)
Expand All @@ -61,11 +58,11 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, ongpu)
@test size(y) == sz

__f = (args...) -> sum(groupnorm(args..., groups, act, epsilon))
soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
soft_fail = [AutoFiniteDiff()]
test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
end

const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64],
const ALL_TEST_CONFIGS = Iterators.product([BFloat16, Float32, Float64],
((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2),
(4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)),
(2, 3),
Expand Down
23 changes: 10 additions & 13 deletions test/normalization/instancenorm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module InstanceNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s

__is_training(::Val{training}) where {training} = training

Expand All @@ -21,20 +21,17 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp

y_simple, nt_simple = instancenorm(x, scale, bias, training, act, epsilon)

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

@test y≈y_simple atol=atol rtol=rtol

# Check the rrules
if !fp16
∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f, x, scale, bias)
@test ∂x≈∂x_simple atol=atol rtol=rtol
@test ∂scale≈∂scale_simple atol=atol rtol=rtol
@test ∂bias≈∂bias_simple atol=atol rtol=rtol
end
∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f, x, scale, bias)
@test ∂x≈∂x_simple atol=atol rtol=rtol
@test ∂scale≈∂scale_simple atol=atol rtol=rtol
@test ∂bias≈∂bias_simple atol=atol rtol=rtol

@test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any
@jet instancenorm(x, scale, bias, training, act, epsilon)
Expand All @@ -49,13 +46,13 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp

if __is_training(training)
__f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon)))
soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
soft_fail = [AutoFiniteDiff()]
test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
end
end

const ALL_TEST_CONFIGS = Iterators.product(
[Float16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)),
[BFloat16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)),
(Val(true), Val(false)), (identity, relu, tanh_fast, sigmoid_fast, anonact))

const TEST_BLOCKS = collect(Iterators.partition(
Expand Down
11 changes: 5 additions & 6 deletions test/normalization/layernorm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module LayerNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics, BFloat16s
using LuxTestUtils: check_approx

function _setup_layernorm(gen_f, aType, T, x_size, affine_shape)
Expand Down Expand Up @@ -33,11 +33,10 @@ function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu
@test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1)
end

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
soft_fail = [AutoFiniteDiff()]
if affine_shape !== nothing
__f = (args...) -> sum(_f(args...))
test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
Expand All @@ -56,7 +55,7 @@ anonact = x -> x^3

const ALL_TEST_CONFIGS = Any[]

for T in (Float16, Float32, Float64),
for T in (BFloat16, Float32, Float64),
x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)),
affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])),
act in (identity, relu, tanh_fast, sigmoid_fast, anonact)
Expand Down
2 changes: 1 addition & 1 deletion test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import Reexport: @reexport

using LuxLib, MLDataDevices
@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote
@reexport using BFloat16s, LuxTestUtils, StableRNGs, Test, Enzyme, Zygote

LuxTestUtils.jet_target_modules!(["LuxLib"])

Expand Down
Loading