Skip to content

Commit

Permalink
Automatically register GPU converters upon loading the glue package.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed May 4, 2021
1 parent e90c3fe commit 433e6a9
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 2 deletions.
9 changes: 9 additions & 0 deletions lib/FluxAMDGPU/src/FluxAMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,13 @@ using Flux: OneHotArray, OneHotLike
Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:ROCArray}}) where N =
AMDGPU.ROCArrayStyle{N}()

function __init__()
if Flux.default_gpu_converter[] === identity
@info "Registering AMDGPU.jl as the default GPU converter"
Flux.default_gpu_converter[] = roc
else
@warn "Not registering AMDGPU.jl as the default GPU converter as another one has been registered already."
end
end

end # module
2 changes: 1 addition & 1 deletion lib/FluxAMDGPU/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using .AMDGPU
ENV["JULIA_GPU_ALLOWSCALAR"] = "false"

using .Flux
Flux.default_gpu_converter[] = AMDGPU.roc
@assert Flux.default_gpu_converter[] == roc

using Zygote
using Zygote: pullback
Expand Down
8 changes: 8 additions & 0 deletions lib/FluxCUDA/src/FluxCUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,13 @@ include("onehot.jl")
include("ctc.jl")
include("cudnn.jl")

function __init__()
if Flux.default_gpu_converter[] === identity
@info "Registering CUDA.jl as the default GPU converter"
Flux.default_gpu_converter[] = cu
else
@warn "Not registering CUDA.jl as the default GPU converter as another one has been registered already."
end
end

end # module
2 changes: 1 addition & 1 deletion lib/FluxCUDA/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using .CUDA
ENV["JULIA_GPU_ALLOWSCALAR"] = "false"

using .Flux
Flux.default_gpu_converter[] = cu
@assert Flux.default_gpu_converter[] == cu

using Zygote
using Zygote: pullback
Expand Down

0 comments on commit 433e6a9

Please sign in to comment.