Skip to content

Commit

Permalink
Merge #1704
Browse files Browse the repository at this point in the history
1704: Gradient definitions for `cpu` & `gpu` r=mcabbott a=mcabbott

Closes #1695. Closes FluxML/Zygote.jl#1005

Not well tested locally, so we will see what CI thinks. 

I'm not very confident this uses `fmap` correctly. Even thornier test cases would be welcome.

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
  • Loading branch information
bors[bot] and mcabbott authored Sep 6, 2021
2 parents 7ada01c + 65e8f70 commit 2468a06
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 8 deletions.
51 changes: 43 additions & 8 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ end
"""
cpu(m)
Moves `m` onto the CPU.
This utility uses [`@functor`](@ref) to properly move structures to the CPU.
Moves `m` onto the CPU, the opposite of [`gpu`](@ref).
Recurses into structs marked [`@functor`](@ref).
```julia-repl
julia> m = Dense(1,2)
Expand All @@ -86,7 +85,24 @@ julia> typeof(m_cpu.W)
Matrix{Float32}
```
"""
cpu(m) = fmap(x -> adapt(Array, x), m)
cpu(x) = fmap(_cpu_array, x; exclude = _isbitsarray)

_cpu_array(x::AbstractArray) = adapt(Array, x)
# adapt(Array, x) materialises some lazy arrays, on which cpu() should do nothing:
_cpu_array(x::AbstractRange) = x
_cpu_array(x::Zygote.FillArrays.AbstractFill) = x
_cpu_array(x::Zygote.OneElement) = x

function Zygote.ChainRules.rrule(::typeof(_cpu_array), x::AbstractArray)
y = _cpu_array(x)
if x === y
# Trivial use: cpu(x::Array) shouldn't push its gradient to GPU
return y, dy -> (Zygote.ChainRules.NoTangent(), dy)
else
# Allows both cpu(x::CuArray) and cpu(x::Adjoint{T,CuArray}):
return y, dy -> (Zygote.ChainRules.NoTangent(), _gpu_array(dy))
end
end

_isbitsarray(::AbstractArray{<:Number}) = true
_isbitsarray(::AbstractArray{T}) where T = isbitstype(T)
Expand All @@ -99,8 +115,7 @@ Moves `m` to the current GPU device, if available. It is a no-op otherwise.
See the [CUDA.jl docs](https://juliagpu.github.io/CUDA.jl/stable/usage/multigpu/)
to help identify the current device.
This works for functions and
any struct with [`@functor`](@ref) defined.
This works for functions, and any struct marked with [`@functor`](@ref).
```julia-repl
julia> m = Dense(1,2)
Expand All @@ -116,11 +131,31 @@ julia> typeof(m_gpu.W) # notice the type of the array changed to a CuArray
CuArray{Float32, 2}
```
"""
gpu(x) = use_cuda[] ? fmap(CUDA.cu, x; exclude = _isbitsarray) : x
gpu(x) = use_cuda[] ? fmap(_gpu_array, x; exclude = _isbitsarray) : x

_gpu_array(x::AbstractArray) = CUDA.cu(x)

# While `cu` moves Arrays to the GPU, we also want to move some structured arrays
# https://github.com/FluxML/Zygote.jl/issues/1005
_gpu_array(x::Zygote.FillArrays.AbstractFill) = CUDA.fill(first(x), size(x)) # gradient of sum
function _gpu_array(x::Zygote.OneElement) # gradient of getindex
y = CUDA.zeros(eltype(x), size(x))
CUDA.@allowscalar y[x.ind...] = x.val
y
end

function Zygote.ChainRules.rrule(::typeof(_gpu_array), x::AbstractArray)
y = _gpu_array(x)
if x === y # trivial case, e.g. gpu(x::Adjoint{T,CuArray})
return y, dy -> (Zygote.ChainRules.NoTangent(), dy)
else
return y, dy -> (Zygote.ChainRules.NoTangent(), _cpu_array(dy))
end
end

# Precision

adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs) # piracy

paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)

Expand Down
55 changes: 55 additions & 0 deletions test/cuda/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,58 @@ end
@test gpu((;a=[SimpleBits(1)])).a isa CuVector{SimpleBits}
end
end

@testset "gpu(cpu(x)) inside gradient" begin
a = randn(Float32, 4, 4)
ca = cu(a)

# Trivial functions
@test gradient(x -> sum(abs, gpu(x)), a)[1] isa Matrix
@test gradient(x -> sum(gpu(x)), a)[1] isa Matrix
@test_skip gradient(x -> sum(gpu(x)), a')[1] isa Matrix # sum(::Adjoint{T,CuArray}) makes a Fill
@test gradient(x -> sum(abs, cpu(x)), ca)[1] isa CuArray
@test gradient(x -> sum(cpu(x)), ca)[1] isa CuArray # This involves FillArray, moved to GPU
@test gradient(x -> sum(cpu(x)), ca')[1] isa CuArray

# Even more trivial: no movement
@test gradient(x -> sum(abs, cpu(x)), a)[1] isa Matrix
@test gradient(x -> sum(abs, cpu(x)), a')[1] isa Matrix
@test gradient(x -> sum(cpu(x)), a)[1] isa typeof(gradient(sum, a)[1]) # FillArray
@test gradient(x -> sum(abs, gpu(x)), ca)[1] isa CuArray
@test_skip gradient(x -> sum(abs, gpu(x)), ca')[1] isa CuArray # KernelError: passing and using non-bitstype argument

# More complicated, Array * CuArray is an error
g0 = gradient(x -> sum(abs, (a * (a * x))), a)[1]
@test g0 gradient(x -> sum(abs, cpu(ca * gpu(a * x))), a)[1]
@test cu(g0) gradient(x -> sum(abs, gpu(a * cpu(ca * x))), ca)[1]

g4 = gradient(x -> sum(a * (a' * x)), a)[1] # no abs, one adjoint
@test g4 gradient(x -> sum(cpu(ca * gpu(a' * x))), a)[1]
@test cu(g4) gradient(x -> sum(gpu(a * cpu(ca' * x))), ca)[1]

# Scalar indexing of an array, needs OneElement to transfer to GPU
# https://github.com/FluxML/Zygote.jl/issues/1005
@test gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3]) == ([2,0,0],)
@test gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9]) == ([2 6 8; 0 2 0; 0 3 0],)

end
@testset "gpu(x) and cpu(x) on structured arrays" begin
# Check first that cpu() is a no-op on these, which adapt(Array, x) presently is not:
@test cpu(1:3) isa UnitRange
@test cpu(range(1,3,length=4)) isa AbstractRange
g1 = Zygote.OneElement(1, (2,3), axes(ones(4,5)))
@test cpu(g1) isa Zygote.OneElement
g2 = Zygote.Fill(1f0,2)
@test cpu(g2) isa Zygote.Fill
g3 = transpose(Float32[1 2; 3 4])
@test parent(cpu(g3)) isa Matrix

# Check that gpu() converts these to CuArrays. This a side-effect of using the same functions
# in gpu() as in the gradient of cpu(). A different design could avoid having gpu() used alone
# move these, if that turns out to be desirable.
@test gpu(g1) isa CuArray
@test gpu(g1) cu(Matrix(g1))
@test gpu(g2) isa CuArray
@test gpu(g2) cu(Vector(g2))
@test parent(gpu(g3)) isa CuArray
end

0 comments on commit 2468a06

Please sign in to comment.