Showing 9 changed files with 466 additions and 0 deletions.
@@ -0,0 +1,13 @@
name = "FluxAMDGPU"
uuid = "15448036-796b-45b3-936c-e3e32bc623ba"
authors = ["Julian P Samaroo <jpsamaroo@jpsamaroo.me>"]
version = "0.1.0"

[deps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"

[compat]
AMDGPU = "0.2"
Flux = "0.12"
julia = "1.6"
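(Not part of the commit: a minimal sketch of how this project file would typically be exercised from the Julia REPL, assuming a local checkout at ./FluxAMDGPU and a working AMDGPU.jl/ROCm setup on Julia 1.6; the path is hypothetical.)

using Pkg
Pkg.develop(path="./FluxAMDGPU")  # hypothetical checkout location
Pkg.test("FluxAMDGPU")            # runs the package's test suite added by this commit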
@@ -0,0 +1,13 @@
module FluxAMDGPU

using Flux
using AMDGPU

### onehot

import Flux: OneHotArray, OneHotLike, _onehot_bool_type

_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:ROCArray}) where N = ROCArray{Bool, N}
Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:ROCArray}}) where N = AMDGPU.ROCArrayStyle{N}()

end # module
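(Not part of the commit: a short sketch of what the two overloads above enable, assuming `gpu` converts arrays to ROCArray as it does in the test suite below.)

using Flux, AMDGPU, FluxAMDGPU

y  = Flux.onehotbatch([1, 2, 3], 1:3)  # OneHotMatrix backed by a CPU index array
cy = gpu(y)                            # indices now live in a ROCArray
cy .+ 1                                # broadcast dispatches to AMDGPU.ROCArrayStyle,
                                       # so the result is a ROCArray on the device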
@@ -0,0 +1,5 @@
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -0,0 +1,63 @@
import .Flux: cpu, gpu
using Statistics: mean
using LinearAlgebra: I, cholesky, Cholesky

@testset "Core" begin
    x = randn(5, 5)
    cx = gpu(x)
    @test cx isa ROCArray

    @test Flux.onecold(gpu([1.0, 2.0, 3.0])) == 3

    x = Flux.onehotbatch([1, 2, 3], 1:3)
    cx = gpu(x)
    @test cx isa Flux.OneHotMatrix && cx.indices isa ROCArray
    @test (cx .+ 1) isa ROCArray

    m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
    cm = gpu(m)

    @test all(p isa ROCArray for p in params(cm))
    @test cm(gpu(rand(10, 10))) isa ROCArray{Float32,2}

    xs = rand(5, 5)
    ys = Flux.onehotbatch(1:5, 1:5)
    @test collect(roc(xs) .+ roc(ys)) ≈ collect(xs .+ ys)

    c = gpu(Conv((2,2), 3=>4))
    x = gpu(rand(10, 10, 3, 2))
    l = c(gpu(rand(10, 10, 3, 2)))
    @test gradient(x -> sum(c(x)), x)[1] isa ROCArray

    c = gpu(CrossCor((2,2), 3=>4))
    x = gpu(rand(10, 10, 3, 2))
    l = c(gpu(rand(10, 10, 3, 2)))
    @test gradient(x -> sum(c(x)), x)[1] isa ROCArray
end

@testset "onecold gpu" begin
    y = Flux.onehotbatch(ones(3), 1:10) |> gpu
    l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
    @test Flux.onecold(y) isa ROCArray
    @test y[3,:] isa ROCArray
    @test Flux.onecold(y, l) == ['a', 'a', 'a']
end

@testset "restructure gpu" begin
    dudt = Dense(1, 1) |> gpu
    p, re = Flux.destructure(dudt)
    foo(x) = sum(re(p)(x))
    @test gradient(foo, roc(rand(1)))[1] isa ROCArray
end

@testset "GPU functors" begin
    @testset "Cholesky" begin
        M = 2.0 * I(10) |> collect
        Q = cholesky(M)
        Q_gpu = Q |> gpu
        @test Q_gpu isa Cholesky{<:Any,<:ROCArray}
        Q_cpu = Q_gpu |> cpu
        @test Q_cpu == cholesky(eltype(Q_gpu).(M))
    end
end
@@ -0,0 +1,221 @@
# Test layers and data/model movements on and off the GPU
# Add tests for layers and their gradients on the GPU
# Most of the forward passes should be fine being applied
# to bitstype objects, but this gives higher coverage for our use-cases
# Check that getting the gradients does not throw

# generic movement tests
@testset "Basic GPU Movement" begin
    @test gradient(x -> sum(gpu(x)), rand(3,3)) isa Tuple
    @test gradient(x -> sum(cpu(x)), gpu(rand(3,3))) isa Tuple
end

# TODO: These layers get into scalar indexing
# `AlphaDropout` throws a compilation error on GPUs,
# whereas the rest are scalar indexing issues.
const BROKEN_LAYERS = Union{DepthwiseConv,
                            AlphaDropout}

function gpu_gradtest(name::String, layers::Vector, x_cpu, args...;
                      setmode=false, test_cpu=true, rtol=1e-5, atol=1e-5)
    @testset "$name GPU grad tests" begin
        for layer in layers
            @testset "$layer GPU grad test" begin
                l_cpu = layer(args...)
                if l_cpu isa BROKEN_LAYERS
                    l_gpu, x_gpu = l_cpu |> gpu, x_cpu |> gpu
                    @test_broken gradient(() -> sum(l_gpu(x_gpu)), Flux.params(l_gpu)) isa Flux.Zygote.Grads
                else
                    gpu_autodiff_test(l_cpu, x_cpu,
                                      test_equal=test_cpu, rtol=rtol, atol=atol)
                    if setmode
                        testmode!(l_cpu)
                        gpu_autodiff_test(l_cpu, x_cpu,
                                          test_equal=test_cpu, rtol=rtol, atol=atol)
                    end
                end
            end
        end
    end
end

# Just to give the testsets in gpu_gradtest meaningful labels
ConvNoBias(args...) = Conv(args...; bias=false)
ConvTransposeNoBias(args...) = ConvTranspose(args...; bias=false)
CrossCorNoBias(args...) = CrossCor(args...; bias=false)
DepthwiseConvNoBias(args...) = DepthwiseConv(args...; bias=false)
r = rand(Float32, 28, 28, 1, 1)
conv_layers = [Conv, ConvNoBias, ConvTranspose, ConvTransposeNoBias, CrossCor, CrossCorNoBias, DepthwiseConv, DepthwiseConvNoBias]
gpu_gradtest("Conv", conv_layers, r, (2,2), 1=>3)

pooling_layers = [MaxPool, MeanPool]
gpu_gradtest("Pooling", pooling_layers, r, (2,2))

adaptive_pooling_layers = [AdaptiveMaxPool, AdaptiveMeanPool]
gpu_gradtest("AdaptivePooling", adaptive_pooling_layers, r, (7,7))

dropout_layers = [Dropout, AlphaDropout]
gpu_gradtest("Dropout", dropout_layers, r, 0.5f0; test_cpu=false, setmode=true) # dropout is not deterministic

layer_norm = [i -> LayerNorm(i; affine=false), i -> LayerNorm(i; affine=true)]
gpu_gradtest("LayerNorm 1", layer_norm, rand(Float32, 8, 8, 3, 4), 8)
gpu_gradtest("LayerNorm 2", layer_norm, rand(Float32, 8, 8, 3, 4), (8,8))
gpu_gradtest("LayerNorm 3", layer_norm, rand(Float32, 5, 4), 5)

batch_norm = [BatchNorm]
gpu_gradtest("BatchNorm 3d", batch_norm, rand(Float32, 8, 8, 8, 3, 4), 3, setmode=false) # bug in CUDA.jl with gradient in testmode
gpu_gradtest("BatchNorm 2d", batch_norm, rand(Float32, 8, 8, 3, 4), 3, setmode=false) # bug in CUDA.jl with gradient in testmode
gpu_gradtest("BatchNorm 1d", batch_norm, rand(Float32, 8, 3, 4), 3, setmode=false) # bug in CUDA.jl with gradient in testmode
gpu_gradtest("BatchNorm fullyconn", batch_norm, rand(Float32, 5, 4), 5, setmode=false)

instancenorm = [i -> InstanceNorm(i; affine=false), i -> InstanceNorm(i; affine=true)]
gpu_gradtest("InstanceNorm 3d", instancenorm, rand(Float32, 8, 8, 8, 3, 4), 3, setmode=true)
gpu_gradtest("InstanceNorm 2d", instancenorm, rand(Float32, 8, 8, 3, 4), 3, setmode=true)
gpu_gradtest("InstanceNorm 1d", instancenorm, rand(Float32, 8, 3, 4), 3, setmode=true)

groupnorm = [(i, j) -> GroupNorm(i, j; affine=false), (i, j) -> GroupNorm(i, j; affine=true)]
gpu_gradtest("GroupNorm 3d", groupnorm, rand(Float32, 8, 8, 8, 12, 4), 12, 3, setmode=true)
gpu_gradtest("GroupNorm 2d", groupnorm, rand(Float32, 8, 8, 12, 4), 12, 3, setmode=true)
gpu_gradtest("GroupNorm 1d", groupnorm, rand(Float32, 8, 3, 12, 4), 12, 3, setmode=true)

upsample = [x -> Upsample(scale=x)]
gpu_gradtest("Upsample 2d", upsample, rand(Float32, 3, 4, 2, 3), (2,2))
gpu_gradtest("Upsample 1d", upsample, rand(Float32, 3, 4, 2, 3), (2,))

pixelshuffle = [PixelShuffle]
gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)

@testset "function layers" begin
    x = rand(Float32, 3, 3)
    gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)
    gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=2)), x)
    gpu_autodiff_test(x -> sum(Flux.normalise(x)), x)
end

@testset "BatchNorm mix stuff" begin | ||
m_cpu = BatchNorm(2) | ||
m_gpu = m_cpu |> gpu | ||
x_cpu = rand(Float32, 3, 2, 2) | ||
x_gpu = x_cpu |> gpu | ||
|
||
## In :auto mode, track statistics only in gradient contest | ||
μ_cpu = copy(m_cpu.μ) | ||
m_cpu(x_cpu) | ||
@test m_cpu.μ ≈ μ_cpu | ||
gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu)) | ||
@test !(m_cpu.μ ≈ μ_cpu) | ||
|
||
μ_gpu = copy(m_gpu.μ) | ||
m_gpu(x_gpu) | ||
@test m_gpu.μ ≈ μ_gpu | ||
gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu)) | ||
@test !(m_gpu.μ ≈ μ_gpu) | ||
|
||
@test Array(m_gpu.μ) ≈ m_cpu.μ | ||
|
||
## In testmode, never track statistics | ||
testmode!(m_cpu) | ||
μ_cpu = copy(m_cpu.μ) | ||
m_cpu(x_cpu) | ||
@test m_cpu.μ ≈ μ_cpu | ||
gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu)) | ||
@test m_cpu.μ ≈ μ_cpu | ||
|
||
testmode!(m_gpu) | ||
μ_gpu = copy(m_gpu.μ) | ||
m_gpu(x_gpu) | ||
@test m_gpu.μ ≈ μ_gpu | ||
gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu)) | ||
@test m_gpu.μ ≈ μ_gpu | ||
|
||
## In trainmode, always track statistics | ||
trainmode!(m_cpu) | ||
μ_cpu = copy(m_cpu.μ) | ||
m_cpu(x_cpu) | ||
@test !(m_cpu.μ ≈ μ_cpu) | ||
μ_cpu = copy(m_cpu.μ) | ||
gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu)) | ||
@test !(m_cpu.μ ≈ μ_cpu) | ||
|
||
trainmode!(m_gpu) | ||
μ_gpu = copy(m_gpu.μ) | ||
m_gpu(x_gpu) | ||
@test !(m_gpu.μ ≈ μ_gpu) | ||
μ_gpu = copy(m_gpu.μ) | ||
gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu)) | ||
@test !(m_gpu.μ ≈ μ_gpu) | ||
|
||
## No errors if input type mistmatch | ||
x_cpu = rand(Float64, 3, 2, 2) | ||
x_gpu = x_cpu |> gpu | ||
m_cpu(x_cpu) | ||
gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu)) | ||
m_gpu(x_gpu) | ||
gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu)) | ||
end | ||
|
||
@testset "Zeros mapped for $cl" for cl in (Conv, ConvTranspose, CrossCor, DepthwiseConv) | ||
l = cl((2,2), 1=>3, bias = false) |> gpu | ||
ip = zeros(Float32, 28,28,1,1) |> gpu | ||
if l isa BROKEN_LAYERS | ||
@test_broken sum(l(ip)) ≈ 0.f0 | ||
@test_broken gradient(() -> sum(l(ip)), Flux.params(l)) isa Flux.Zygote.Grads | ||
else | ||
@test sum(l(ip)) ≈ 0.f0 | ||
gs = gradient(() -> sum(l(ip)), Flux.params(l)) | ||
@test l.bias ∉ gs.params | ||
end | ||
end | ||
|
||
@testset "Dense with Zeros bias" begin | ||
l = Dense(ones(Float32, 4,3), Flux.Zeros()) |> gpu | ||
ip = zeros(Float32, 3, 7) |> gpu | ||
|
||
@test sum(l(ip)) ≈ 0.f0 | ||
gs = gradient(() -> sum(l(ip)), Flux.params(l)) | ||
@test l.b ∉ gs.params | ||
end | ||
|
||
@testset "Two-streams Bilinear" begin | ||
x = zeros(Float32,10,9) |> gpu | ||
y = zeros(Float32,2,9) |> gpu | ||
b = Flux.Bilinear(10, 2, 3) |> gpu | ||
@test size(b(x,y)) == (3,9) | ||
@test sum(abs2, b(x,y)) ≈ 0f0 | ||
gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b)) | ||
b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu | ||
gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu)) | ||
for (pgpu, pcpu) in zip(params(b), params(b_cpu)) | ||
@test gs_cpu[pcpu] ≈ Array(gs_gpu[pgpu]) | ||
end | ||
end | ||
|
||
@testset "Parallel" begin | ||
@testset "zero sum" begin | ||
input = randn(10, 10, 10, 10) |> gpu | ||
layer_gpu = Parallel(+, zero, identity) |> gpu | ||
@test layer_gpu(input) == input | ||
@test layer_gpu(input) isa ROCArray | ||
end | ||
|
||
@testset "vararg input" begin | ||
inputs = (randn(10), randn(5), randn(4)) .|> gpu | ||
layer = Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2)) |> gpu | ||
@test size(layer(inputs)) == (2,) | ||
end | ||
|
||
@testset "gradient" begin | ||
input_cpu = randn(10, 10, 10, 10) | ||
input_gpu = input_cpu |> gpu | ||
layer_cpu = Parallel(+, x -> zero(x), identity) | ||
layer_gpu = layer_cpu |> gpu | ||
gs_cpu = gradient(() -> sum(abs2.(layer_cpu(input_cpu))), params(layer_cpu)) | ||
gs_gpu = gradient(() -> sum(abs2.(layer_gpu(input_gpu))), params(layer_gpu)) | ||
for (pgpu, pcpu) in zip(params(layer_cpu), params(layer_gpu)) | ||
@test gs_cpu[pcpu] ≈ gs_gpu[pgpu] | ||
end | ||
end | ||
end |
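(Not part of the commit: `gpu_autodiff_test` is pulled in from Flux's shared test utilities. Purely as an illustration, the sketch below shows the kind of CPU-vs-GPU gradient comparison such a helper performs; the name `compare_gradients` and its details are assumptions, not the real helper.)

using Flux, Test

function compare_gradients(layer_cpu, x_cpu; rtol=1e-5, atol=1e-5)
    layer_gpu, x_gpu = gpu(layer_cpu), gpu(x_cpu)
    # implicit-parameter gradients on the CPU and on the GPU
    gs_cpu = gradient(() -> sum(layer_cpu(x_cpu)), Flux.params(layer_cpu))
    gs_gpu = gradient(() -> sum(layer_gpu(x_gpu)), Flux.params(layer_gpu))
    for (p_cpu, p_gpu) in zip(Flux.params(layer_cpu), Flux.params(layer_gpu))
        # bring the device gradient back to the host before comparing
        @test isapprox(Array(gs_gpu[p_gpu]), gs_cpu[p_cpu]; rtol=rtol, atol=atol)
    end
end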
@@ -0,0 +1,38 @@
using .Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy, binary_focal_loss, focal_loss

@testset "Losses" begin

    x = [1., 2., 3.]
    cx = gpu(x)
    @test crossentropy(x, x) ≈ crossentropy(cx, cx)
    @test crossentropy(x, x, agg=identity) ≈ crossentropy(cx, cx, agg=identity) |> cpu
    @test crossentropy(x, x, agg=x->mean([1.0;2.0;3.0].*x)) ≈ crossentropy(cx, cx, agg=x->mean(gpu([1.0;2.0;3.0]).*x))

    x = [-1.1491, 0.8619, 0.3127]
    y = [1, 1, 0.]
    @test binarycrossentropy(σ.(x), y) ≈ binarycrossentropy(gpu(σ.(x)), gpu(y))
    @test logitbinarycrossentropy(x, y) ≈ logitbinarycrossentropy(gpu(x), gpu(y))

    x = [0.268941 0.5 0.268941
         0.731059 0.5 0.731059]
    y = [0 1 0
         1 0 1]
    @test binary_focal_loss(x, y) ≈ binary_focal_loss(gpu(x), gpu(y))

    x = softmax(reshape(-7:7, 3, 5) .* 1f0)
    y = [1 0 0 0 1
         0 1 0 1 0
         0 0 1 0 0]
    @test focal_loss(x, y) ≈ focal_loss(gpu(x), gpu(y))

    @testset "GPU grad tests" begin
        x = rand(Float32, 3, 3)
        y = rand(Float32, 3, 3)

        for loss in ALL_LOSSES
            gpu_autodiff_test(loss, x, y)
        end
    end

end #testset