diff --git a/lib/FluxAMDGPU/src/FluxAMDGPU.jl b/lib/FluxAMDGPU/src/FluxAMDGPU.jl
index 844ba954af..acffd3b4c8 100644
--- a/lib/FluxAMDGPU/src/FluxAMDGPU.jl
+++ b/lib/FluxAMDGPU/src/FluxAMDGPU.jl
@@ -10,4 +10,13 @@ using Flux: OneHotArray, OneHotLike
 Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:ROCArray}}) where N =
     AMDGPU.ROCArrayStyle{N}()
 
+function __init__()
+    if Flux.default_gpu_converter[] === identity
+        @info "Registering AMDGPU.jl as the default GPU converter"
+        Flux.default_gpu_converter[] = roc
+    else
+        @warn "Not registering AMDGPU.jl as the default GPU converter as another one has been registered already."
+    end
+end
+
 end # module
diff --git a/lib/FluxAMDGPU/test/runtests.jl b/lib/FluxAMDGPU/test/runtests.jl
index 7a030438d6..8b6c3a65f8 100644
--- a/lib/FluxAMDGPU/test/runtests.jl
+++ b/lib/FluxAMDGPU/test/runtests.jl
@@ -9,7 +9,7 @@ using .AMDGPU
 
 ENV["JULIA_GPU_ALLOWSCALAR"] = "false"
 using .Flux
-Flux.default_gpu_converter[] = AMDGPU.roc
+@assert Flux.default_gpu_converter[] == roc
 
 using Zygote
 using Zygote: pullback
diff --git a/lib/FluxCUDA/src/FluxCUDA.jl b/lib/FluxCUDA/src/FluxCUDA.jl
index 75de37eb34..2936118bff 100644
--- a/lib/FluxCUDA/src/FluxCUDA.jl
+++ b/lib/FluxCUDA/src/FluxCUDA.jl
@@ -9,5 +9,13 @@ include("onehot.jl")
 include("ctc.jl")
 include("cudnn.jl")
 
+function __init__()
+    if Flux.default_gpu_converter[] === identity
+        @info "Registering CUDA.jl as the default GPU converter"
+        Flux.default_gpu_converter[] = cu
+    else
+        @warn "Not registering CUDA.jl as the default GPU converter as another one has been registered already."
+    end
+end
 
 end # module
diff --git a/lib/FluxCUDA/test/runtests.jl b/lib/FluxCUDA/test/runtests.jl
index f92f8f8522..233ba77df4 100644
--- a/lib/FluxCUDA/test/runtests.jl
+++ b/lib/FluxCUDA/test/runtests.jl
@@ -9,7 +9,7 @@ using .CUDA
 
 ENV["JULIA_GPU_ALLOWSCALAR"] = "false"
 using .Flux
-Flux.default_gpu_converter[] = cu
+@assert Flux.default_gpu_converter[] == cu
 
 using Zygote
 using Zygote: pullback