Merge #1557
1557: Fix #1556 r=DhairyaLGandhi a=DhairyaLGandhi

Reverts some of the `Dense` changes, since those caused issues downstream. cc @ChrisRackauckas @darsnack @CarloLucibello

Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
bors[bot] and DhairyaLGandhi authored Mar 31, 2021
2 parents 28f34d1 + 35d737b commit c563189
Showing 3 changed files with 31 additions and 34 deletions.
46 changes: 21 additions & 25 deletions src/layers/basic.jl
@@ -69,7 +69,7 @@ extraChain(::Tuple{}, x) = ()
 
 
 """
-    Dense(in, out, σ=identity; bias=true, init=glorot_uniform)
+    Dense(in, out, σ = identity; bias = true, init = glorot_uniform)
     Dense(W::AbstractMatrix, [bias, σ])
 
 Create a traditional `Dense` layer, whose forward pass is given by:
@@ -81,7 +81,7 @@ as an `in × N` matrix, or any array with `size(x,1) == in`.
 The out `y` will be a vector of length `out`, or a batch with
 `size(y) == (out, size(x)[2:end]...)`
-Keyword `bias=false` will switch off trainable bias for the layer.
+Keyword `bias = false` will switch off trainable bias for the layer.
 The initialisation of the weight matrix is `W = init(out, in)`, calling the function
 given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform).
 The weight matrix and/or the bias vector (of length `out`) may also be provided explicitly.
@@ -109,45 +109,41 @@ julia> Flux.params(d1) # no trainable bias
 Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
 ```
 """
-struct Dense{F, M<:AbstractMatrix, B}
-  weight::M
-  bias::B
+struct Dense{F,S<:AbstractArray,T}
+  weight::S
+  bias::T
   σ::F
-  function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
-    b = create_bias(W, bias, size(W,1))
-    new{F,M,typeof(b)}(W, b, σ)
-  end
 end
 
-function Dense(in::Integer, out::Integer, σ = identity;
-               initW = nothing, initb = nothing,
-               init = glorot_uniform, bias=true)
+Dense(W, b) = Dense(W, b, identity)
 
-  W = if initW !== nothing
-    Base.depwarn("keyword initW is deprecated, please use init (which similarly accepts a funtion like randn)", :Dense)
-    initW(out, in)
-  else
-    init(out, in)
+Dense(W::AbstractArray, b::Bool = true, σ = identity) =
+  Dense(W, create_bias(W, b, size(W,1)), σ)
+
+function Dense(in::Integer, out::Integer, σ = identity; initW = nothing,
+               init = glorot_uniform, initb = nothing, bias::Bool = true)
+  if initW !== nothing
+    Base.depwarn("initW is deprecated, please use the `init` keyword instead", :Dense)
+    init = initW
   end
 
-  b = if bias === true && initb !== nothing
-    Base.depwarn("keyword initb is deprecated, please simply supply the bias vector, bias=initb(out)", :Dense)
-    initb(out)
+  if initb !== nothing
+    Base.depwarn("initb is deprecated, please use the array based constructors instead", :Dense)
+    initb = initb
   else
-    bias
+    initb = zeros
  end
 
-  return Dense(W, b, σ)
+  Dense(init(out, in), bias ? initb(out) : Zeros(), σ)
 end
 
 @functor Dense
 
 function (a::Dense)(x::AbstractVecOrMat)
   W, b, σ = a.weight, a.bias, a.σ
-  return σ.(W*x .+ b)
+  σ.(W * x .+ b)
 end
 
-(a::Dense)(x::AbstractArray) =
+(a::Dense)(x) =
   reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
 
 function Base.show(io::IO, l::Dense)
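For reference, a short usage sketch of the `Dense` constructor surface this commit restores (assuming Flux as of this commit; the sizes and activation below are arbitrary examples, not taken from the diff):

```julia
using Flux

d = Dense(2, 3, tanh)               # weight = glorot_uniform(3, 2), trainable bias of length 3
d_nobias = Dense(2, 3; bias=false)  # bias is Zeros(), i.e. switched off
d_given  = Dense(rand(Float32, 3, 2), zeros(Float32, 3), tanh)  # weight and bias supplied directly

x = rand(Float32, 2, 5)             # a batch of five length-2 inputs
size(d(x))                          # (3, 5)

# The restored `(a::Dense)(x)` method reshapes higher-rank input through the
# matrix path, so trailing dimensions are preserved:
size(d(rand(Float32, 2, 4, 6)))     # (3, 4, 6)
```

The deprecated keywords still work but warn: per the diff above, `Dense(2, 3; initW=randn)` forwards `initW` to `init`.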
13 changes: 7 additions & 6 deletions test/layers/basic.jl
@@ -40,16 +40,17 @@ import Flux: activations
     @test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type
     @test_skip Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match
 
-    @test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
+    @test_skip Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
     @test_skip Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}
 
 
     @test_throws MethodError Dense(10, 10.5)
     @test_throws MethodError Dense(10, 10.5, tanh)
-    @test_throws DimensionMismatch Dense(3,4; bias=rand(5))
-    @test_throws DimensionMismatch Dense(rand(4,3), rand(5))
-    @test_throws MethodError Dense(rand(5))
-    @test_throws MethodError Dense(rand(5), rand(5))
-    @test_throws MethodError Dense(rand(5), rand(5), tanh)
+    # @test_throws DimensionMismatch Dense(3,4; bias=rand(5))
+    # @test_throws DimensionMismatch Dense(rand(4,3), rand(5))
+    # @test_throws MethodError Dense(rand(5))
+    # @test_throws MethodError Dense(rand(5), rand(5))
+    # @test_throws MethodError Dense(rand(5), rand(5), tanh)
   end
   @testset "dimensions" begin
     @test length(Dense(10, 5)(randn(10))) == 5
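A note on the mechanism used above: `@test_skip` comes from Julia's `Test` stdlib and records the expression as `Broken` without evaluating it, so the suite keeps a visible reminder of the disabled checks. A minimal standalone illustration (not from the repository):

```julia
using Test

@testset "skip demo" begin
    @test 1 + 1 == 2                   # runs and passes
    @test_skip error("not evaluated")  # not run; reported under "Broken"
end
```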
6 changes: 3 additions & 3 deletions test/utils.jl
@@ -360,9 +360,9 @@ end
 end
 
 @testset "$b1 to $b2" for (b1, b2, be) in (
-  (Flux.zeros, ones, ones),           # Load ones as bias to a model with zeros as bias -> model gets ones as bias
-  (ones, nobias, Flux.zeros),         # Load Zeros as bias to a model with ones as bias -> model gets zeros as bias
-  (nobias, ones, nobias),             # Load ones as bias to a model with Zeros as bias -> model bias does not change
+  (Flux.zeros, Flux.ones, Flux.ones), # Load ones as bias to a model with zeros as bias -> model gets ones as bias
+  (Flux.ones, nobias, Flux.zeros),    # Load Zeros as bias to a model with ones as bias -> model gets zeros as bias
+  (nobias, Flux.ones, nobias),        # Load ones as bias to a model with Zeros as bias -> model bias does not change
 )
   m1 = dm(b1)
   m2 = dm(b2)
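The rows above rely on helpers (`dm`, `nobias`) defined earlier in test/utils.jl and not shown in this diff. A rough, hypothetical reconstruction of what the first row exercises via `Flux.loadparams!` (the helper definitions here are stand-ins, not the repository's exact code):

```julia
using Flux
using Flux: Zeros, loadparams!

nobias(n) = Zeros()                             # stand-in: "no bias" marker
dm(bias) = Dense(ones(Float32, 3, 3), bias(3))  # stand-in: one-layer model with the given bias

m1 = dm(Flux.zeros)               # destination model, bias starts at zero
m2 = dm(Flux.ones)                # source model, bias is all ones
loadparams!(m1, Flux.params(m2))  # load ones-bias params into the zeros-bias model
m1.bias                           # now all ones, as the first row's comment describes
```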
