Fix some issues with Zeros #1374

Closed
wants to merge 10 commits into from
5 changes: 3 additions & 2 deletions src/functor.jl
@@ -1,4 +1,4 @@
import Adapt: adapt, adapt_storage
import Adapt: adapt, adapt_storage, adapt_structure
using LinearAlgebra: Cholesky
using Zygote: IdSet
import Functors: @functor, functor, fmap
@@ -61,7 +61,7 @@ mapleaves(f, x) = fmap(f, x)

function loadparams!(m, xs)
for (p, x) in zip(params(m), xs)
size(p) == size(x) ||
(any(y -> y isa Zeros, (p, x)) || size(p) == size(x)) ||
error("Expected param size $(size(p)), got $(size(x))")
copyto!(p, x)
end
@@ -78,6 +78,7 @@ gpu(x) = use_cuda[] ? fmap(CUDA.cu, x) : x
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)

paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
adapt_structure(to, x::Flux.Zeros) = x # Keep Zeros intact so that e.g. precision conversion doesn't leave a layer with a scalar bias

f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)
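
A minimal usage sketch (not part of the diff) of what the adapt_structure overload above is meant to guarantee: converting a model's precision leaves a Zeros bias untouched instead of collapsing it to a scalar. The layer construction mirrors the tests further down; the variable names are illustrative only.

using Flux

m = Dense(rand(Float32, 2, 3), Flux.Zeros())  # layer with a non-trainable, all-zero bias
m32 = f32(m)
m32.b === m.b === Flux.Zeros()                # bias survives conversion as Zeros, not as a scalar
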
2 changes: 2 additions & 0 deletions src/utils.jl
@@ -372,6 +372,7 @@ end
function _restructure(m, xs)
i = 0
fmap(m) do x
x isa Zeros && return x
x isa AbstractArray || return x
x = reshape(xs[i.+(1:length(x))], size(x))
i += length(x)
@@ -407,6 +408,7 @@ modifications to the weight vector (for example, with a hypernetwork).
function destructure(m)
xs = Zygote.Buffer([])
fmap(m) do x
x isa Zeros && return x # Don't include Zeros: everything collected here is allowed to be changed by the caller
x isa AbstractArray && push!(xs, x)
return x
end
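
To illustrate the two changes above, here is a rough sketch (assumed behaviour, mirroring the tests at the end of this PR) of destructure/_restructure with a Zeros bias: the bias is neither flattened into the parameter vector nor replaced on reconstruction.

using Flux

m = Dense(ones(Float32, 2, 3), Flux.Zeros())
p, re = Flux.destructure(m)
length(p) == 6          # only the 2×3 weight is collected; the Zeros bias is skipped
re(2 .* p).b === m.b    # _restructure passes the Zeros bias through unchanged
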
58 changes: 42 additions & 16 deletions src/zeros.jl
@@ -1,5 +1,6 @@
import Base: +, -, *, reshape, size
import Base.Broadcast: broadcasted, Broadcasted, BroadcastStyle
import Zygote: unbroadcast

"""
Zeros()
@@ -54,6 +55,8 @@ Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,

Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))

# Without this, reshape returns a ReshapedArray and all the methods below are circumvented
Base.reshape(xs::Zeros{T}, dims::Vararg{Union{Colon, Int64},N}) where {T, N} = Zeros(T, Base._reshape_uncolon(xs, dims)...)
@adjoint reshape(xs::Zeros{T}, dims...) where T =
reshape(xs, dims...), _ -> nothing

@@ -65,42 +68,65 @@ for f in (:+, :-)
end
end

+(a::Zeros, b::AbstractArray) = b + a
-(a::Zeros, b::AbstractArray) = -b + a
# This is a bit of whack-a-mole: avoiding ambiguity while still making sure we capture all signatures...
@adjoint +(a::AbstractArray{<:Number}, b::Zeros) = a + b, ā -> (ā, nothing)
@adjoint +(a::Zeros, b::AbstractArray{<:Number}) = b + a, b̄ -> (nothing, b̄)
@adjoint +(a::Zeros, b::Zeros) = b + a, _ -> (nothing, nothing)

@adjoint -(a::AbstractArray{<:Number}, b::Zeros) = a - b, ā -> (ā, nothing)
@adjoint -(a::Zeros, b::AbstractArray{<:Number}) = -b + a, b̄ -> (nothing, -b̄)
@adjoint -(a::Zeros{<:Number}, b::Zeros{<:Number}) = a - b, _ -> (nothing, nothing)


Base.copy(xs::Zeros{T,N}) where {T,N} = xs

Base.copyto!(dest::AbstractArray, ::Zeros) = copyto!(dest, zeros(size(dest)))
Base.copyto!(dest::Zeros, ::AbstractArray) = dest
Base.copyto!(dest::Zeros, ::Zeros) = dest

# Define broadcasting behaviour
for op in (:+, :-)
@eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros)
@eval function broadcasted(::typeof($op), a::AbstractArray{<:Number}, b::Zeros)
bs = Broadcast.broadcast_shape(size(a), size(b))
size(a) == bs && return a
sz = similar(a, bs)
sz .= a
return sz
end
end

broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(+, b, a)
broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(+, -b, a)
@adjoint broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Zeros) = broadcasted(+, a, b), ā -> (nothing, unbroadcast(a, ā), nothing)
@adjoint broadcasted(::typeof(+), a::Zeros, b::AbstractArray{<:Number}) = broadcasted(+, b, a), b̄ -> (nothing, nothing, unbroadcast(b, b̄))

function broadcasted(::typeof(*), a::AbstractArray, b::Zeros)
Zeros(Broadcast.broadcast_shape(size(a), size(b))...)
@adjoint broadcasted(::typeof(-), a::Zeros, b::AbstractArray{<:Number}) = broadcasted(+, -b, a), b̄ -> (nothing, nothing, -unbroadcast(b, b̄))
@adjoint broadcasted(::typeof(-), a::AbstractArray{<:Number}, b::Zeros) = broadcasted(+, a, b), ā -> (nothing, unbroadcast(a, ā), nothing)

@adjoint function broadcasted(::typeof(*), a::AbstractArray{T}, b::Zeros) where T<:Number
zs = zeros(T, Broadcast.broadcast_shape(size(a), size(b))...)
zs, ā -> (nothing, unbroadcast(a, zs), nothing)
end

broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = broadcasted(*, b, a)
@adjoint function broadcasted(::typeof(*), a::Zeros, b::AbstractArray{T}) where T<:Number
zs = zeros(T, Broadcast.broadcast_shape(size(a), size(b))...)
zs, b̄ -> (nothing, nothing, unbroadcast(b, zs))
end

for op in (:+, :-, :*)
@eval broadcasted(::typeof($op), a::Zeros, b::Zeros) = Zeros(Broadcast.broadcast_shape(size(a), size(b))...)
@eval @adjoint broadcasted(::typeof($op), a::Zeros, b::Zeros) = Zeros(Broadcast.broadcast_shape(size(a), size(b))...), _ -> (nothing, nothing, nothing)
# To avoid ambiguity with the 0-dimensional Zeros methods below
@eval @adjoint broadcasted(::typeof($op), a::Zeros, b::Zeros{T, 0}) where T<:Number = a, _ -> (nothing, nothing, nothing)
@eval @adjoint broadcasted(::typeof($op), a::Zeros{T, 0}, b::Zeros{T, 0}) where T<:Number = a, _ -> (nothing, nothing, nothing)
@eval @adjoint broadcasted(::typeof($op), a::Zeros{T, 0}, b::Zeros) where T<:Number = a, _ -> (nothing, nothing, nothing)
end

# Shortcuts that avoid scalar indexing and intermediate allocations.
# Since they replicate a little of what we expect Base to do, it should
# be possible to remove them in the future, but for now they help with
# performance.
broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a
broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b
broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a
broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b
broadcasted(::typeof(*), a::AbstractArray, b::Zeros{T,0}) where T = zero(a)
broadcasted(::typeof(*), a::Zeros{T,0}, b::AbstractArray) where T = zero(b)
broadcasted(::typeof(/), a::Zeros{T,0}, b::AbstractArray) where T = zero(b)
@adjoint broadcasted(::typeof(+), a::AbstractArray{T}, b::Zeros{<:Number,0}) where T<:Number = a, ā -> (nothing, unbroadcast(a, ā), nothing)
@adjoint broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray{<:Number}) where T<:Number = b, b̄ -> (nothing, nothing, unbroadcast(b, b̄))
@adjoint broadcasted(::typeof(-), a::AbstractArray{<:Number}, b::Zeros{T,0}) where T<:Number = a, ā -> (nothing, unbroadcast(a, ā), nothing)
@adjoint broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray{<:Number}) where T<:Number = -b, b̄ -> (nothing, nothing, -unbroadcast(b, b̄))
@adjoint broadcasted(::typeof(*), a::AbstractArray{<:Number}, b::Zeros{T,0}) where T<:Number = zero(a), ā -> (nothing, unbroadcast(a, Zeros(eltype(a), size(a)...)), nothing)
@adjoint broadcasted(::typeof(*), a::Zeros{T,0}, b::AbstractArray{<:Number}) where T<:Number = zero(b), b̄ -> (nothing, nothing, unbroadcast(b, Zeros(eltype(b), size(b)...)))
@adjoint broadcasted(::typeof(/), a::Zeros{T,0}, b::AbstractArray{<:Number}) where T<:Number = zero(b), b̄ -> (nothing, nothing, unbroadcast(b, Zeros(eltype(b), size(b)...)))
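
A short, hedged sketch of the broadcasting rules defined above (argument shapes chosen arbitrarily): adding a Zeros operand is a no-op on the forward pass, and the adjoints produce no gradient for the Zeros argument.

using Flux

a = rand(Float32, 2, 3)
b = Flux.Zeros(2, 3)
a .+ b == a                                    # broadcasted + returns `a` unchanged
da, db = gradient((x, y) -> sum(x .+ y), a, b)
db === nothing                                 # no gradient is accumulated for Zeros
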
29 changes: 28 additions & 1 deletion test/cuda/layers.jl
@@ -45,8 +45,13 @@ end

# Repeats from Conv, CrossCor

# Just to give the testsets in gradtest meaningful labels
ConvNoBias(args...) = Conv(args...; bias=Flux.Zeros())
ConvTransposeNoBias(args...) = ConvTranspose(args...; bias=Flux.Zeros())
CrossCorNoBias(args...) = CrossCor(args...; bias=Flux.Zeros())
DepthwiseConvNoBias(args...) = DepthwiseConv(args...; bias=Flux.Zeros())
r = rand(Float32, 28, 28, 1, 1)
conv_layers = [Conv, ConvTranspose, CrossCor, DepthwiseConv]
conv_layers = [Conv, ConvNoBias, ConvTranspose, ConvTransposeNoBias, CrossCor, CrossCorNoBias, DepthwiseConv, DepthwiseConvNoBias]
gradtest("Conv", conv_layers, r, (2,2), 1=>3)

pooling_layers = [MaxPool, MeanPool]
@@ -95,3 +100,25 @@ end
stateless_gradtest_broadcasted(layer, x, y)
end
end

@testset "Zeros mapped for $cl" for cl in (Conv, ConvTranspose, CrossCor, DepthwiseConv)
l = cl((2,2), 1=>3, bias = Flux.Zeros()) |> gpu
ip = zeros(Float32, 28,28,1,1) |> gpu
if cl in BROKEN_LAYERS
@test_broken sum(l(ip)) ≈ 0.f0
@test_broken gradient(() -> sum(l(ip)), Flux.params(l)) isa Flux.Zygote.Grads
else
@test sum(l(ip)) ≈ 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test gs[l.bias] === nothing
end
end

@testset "Dense with Zeros bias" begin
l = Dense(ones(Float32, 4,3), Flux.Zeros()) |> gpu
ip = zeros(Float32, 3, 7) |> gpu

@test sum(l(ip)) ≈ 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test gs[l.b] === nothing
end
1 change: 1 addition & 0 deletions test/layers/basic.jl
@@ -45,6 +45,7 @@ import Flux: activations
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
@test Dense(10, 2, identity, initW = ones, initb = Zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
end

@testset "Diagonal" begin
12 changes: 7 additions & 5 deletions test/layers/conv.jl
@@ -42,11 +42,13 @@ end
op = bias(ip)
@test sum(op) == prod(size(op))

bias = Conv((2,2), 1=>3, bias = Flux.Zeros())
op = bias(ip)
@test sum(op) ≈ 0.f0
gs = gradient(() -> sum(bias(ip)), Flux.params(bias))
@test gs[bias.bias] == nothing
@testset "Zeros mapped through $lmap" for lmap in (identity, cpu, f32)
bias = Conv((2,2), 1=>3, bias = Flux.Zeros()) |> lmap
op = bias(ip)
@test sum(op) ≈ 0.f0
gs = gradient(() -> sum(bias(ip)), Flux.params(bias))
@test gs[bias.bias] === nothing
Member
Need to test this with CUDA as well

Contributor Author
I guess you are referring to the test that the gradient w.r.t. Zeros is nothing, right? Will add it right away!

end

# Train w/o bias and make sure no convergence happens
# when only bias can be converged
5 changes: 3 additions & 2 deletions test/optimise.jl
@@ -14,9 +14,10 @@ using Random
Nesterov(), RMSProp(), Momentum()]
Random.seed!(42)
w′ = randn(10, 10)
loss(x) = Flux.Losses.mse(w*x, w′*x)
b = Flux.Zeros()
loss(x) = Flux.Losses.mse(w*x, w′*x .+ b)
for t = 1: 10^5
θ = Params([w′])
θ = Params([w′, b])
x = rand(10)
θ̄ = gradient(() -> loss(x), θ)
Optimise.update!(opt, θ, θ̄)
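
A condensed sketch of what the updated test above exercises (assuming, as the test does, that update! simply skips parameters whose gradient is nothing): a Zeros bias can be placed in Params without ever being modified.

using Flux, Flux.Optimise

w, w′ = rand(10, 10), randn(10, 10)
b = Flux.Zeros()
loss(x) = Flux.Losses.mse(w*x, w′*x .+ b)
θ = Flux.Params([w′, b])
θ̄ = gradient(() -> loss(rand(10)), θ)
θ̄[b] === nothing                       # no gradient entry for the Zeros bias
Optimise.update!(Descent(0.1), θ, θ̄)   # w′ is updated; b stays Flux.Zeros()
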
161 changes: 160 additions & 1 deletion test/utils.jl
@@ -1,5 +1,5 @@
using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, stack, unstack
using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, stack, unstack, Zeros
using StatsBase: var, std
using Random
using Test
@@ -129,10 +129,169 @@ end
@test eltype(f32(f64(m))[1].W) == Float32
end

@testset "Zeros" begin
m = Dense(randn(2,3), Zeros())
@test f64(m).b === m.b === Zeros()
@test f32(m).b === m.b === Zeros()

@testset "Gradients for broadcasted $op with sizes $s" for op in (+,-,*), s in ((1,), (2,3))
o = ones(s)
z = zeros(s)
Zs = Zeros(s...)
Z0 = Zeros()

@testset "Explicit" begin
gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...)
g = gfun(o, z)
@test gfun(o, Zs) == (g[1], nothing)
@test gfun(o, Z0) == (g[1], nothing)

g = gfun(z, o)
@test gfun(Zs, o) == (nothing, g[2])
@test gfun(Z0, o) == (nothing, g[2])

@test gfun(Zs, Zs) == gfun(Z0, Z0) == gfun(Zs, Z0) == gfun(Z0, Zs) == (nothing, nothing)
end

@testset "Implicit" begin
gfun(args...) = gradient(() -> sum(op.(args...)), params(collect(args)))
g = gfun(o, z)

gres = gfun(o, Zs)
@test gres[o] == g[o]
@test gres[Zs] === nothing

gres = gfun(o, Z0)
@test gres[o] == g[o]
@test gres[Z0] === nothing

g = gfun(z, o)

gres = gfun(Zs, o)
@test gres[o] == g[o]
@test gres[Zs] === nothing


gres = gfun(Z0, o)
@test gres[o] == g[o]
@test gres[Z0] === nothing

gfunc(args...) = gfun(args...).grads |> values |> Tuple
@test gfunc(Zs, Zs) == gfunc(Z0, Z0) == (nothing,)
@test gfunc(Zs, Z0) == gfunc(Z0, Zs) == (nothing, nothing)
end
end

@testset "Gradients for broadcasted / with sizes $s" for s in ((1,), (2,3))
o = ones(s)
z = zeros(s)
Z = Zeros() # Only defined for 0-dim

@testset "Explicit" begin
gfun(args...) = gradient((x, y) -> sum(x ./ y), args...)
g = gfun(z, o)
@test gfun(Z, o) == (nothing, g[2])
end

@testset "Implicit" begin
gfun(x,y) = gradient(() -> sum(x ./ y), params([x,y]))

g = gfun(z, o)
gres = gfun(Z, o)
@test gres[o] == g[o]
@test gres[Z] === nothing
end
end

@testset "Gradients for $op with sizes $s" for op in (+,-), s in (tuple(), (1,), (2,3))
o = ones(s)
z = zeros(s)
Z = Zeros(s...)


@testset "Explicit" begin
gfun(args...) = gradient((x, y) -> sum(op(x,y)), args...)

g = gfun(o, z)
@test gfun(o, Z) == (g[1], nothing)

g = gfun(z, o)
@test gfun(Z, o) == (nothing, g[2])

@test gfun(Z, Z) == (nothing, nothing)
end

@testset "Implicit" begin
gfun(args...) = gradient(() -> sum(op(args...)), params(collect(args)))
g = gfun(o, z)
gres = gfun(o, Z)
@test gres[o] == g[o]
@test gres[Z] === nothing

g = gfun(z, o)
gres = gfun(Z, o)
@test gres[o] == g[o]
@test gres[Z] === nothing

gfunc(args...) = gfun(args...).grads |> values |> Tuple
@test gfunc(Z, Z) == (nothing,)
end
end
end

@testset "Stacking" begin
stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ]
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
@test unstack(stacked_array, 2) == unstacked_array
@test stack(unstacked_array, 2) == stacked_array
@test stack(unstack(stacked_array, 1), 1) == stacked_array
end

@testset "Param remapping" begin
ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...)
dl(nin, nout, bias) = Dense(ls(nin, nout), bias(nout))
dm(bias) = Chain(
dl(3, 5, bias),
dl(5, 4, bias),
dl(4, 3, bias)
)

testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
@test l1.W == l2.W
@test l1.b == l2.b
@test typeof(l1.b) === typeof(l2.b)
end

ZerosNoShape(::Integer) = Zeros() # Just to get a readable name in testsets
@testset "loadparams!" begin
import Flux: loadparams!
pararray(m) = mapreduce(l -> collect(params(l).order), vcat, m)
@testset "Bias type $bt" for bt in (zeros, Zeros, ZerosNoShape)
m = dm(bt)
loadparams!(m, params(m))
testdense(m, bt)
end

@testset "$b1 to $b2" for (b1, b2, be) in (
(zeros, ones, ones), # Load ones as bias to a model with zeros as bias -> model gets ones as bias
(ones, Zeros, zeros), # Load Zeros as bias to a model with ones as bias -> model gets zeros as bias
(Zeros, ones, Zeros), # Load ones as bias to a model with Zeros as bias -> model bias does not change
(ones, ZerosNoShape, zeros), # Load 0d Zeros as bias to a model with ones as bias -> model gets zeros as bias
#(ZerosNoShape, ones, ZerosNoShape), # Does not work as loadparams! uses params, which is backed by a set -> different number of parameters in the models
)
m1 = dm(b1)
m2 = dm(b2)
loadparams!(m1, pararray(m2))
testdense(m1, be)
end
end

@testset "destructure" begin
import Flux: destructure
@testset "Bias type $bt" for bt in (zeros, Zeros, ZerosNoShape)
m = dm(bt)
p, re = destructure(m)
testdense(re(p), bt)
end
end
end