Skip to content

Commit

Permalink
Merge branch 'master' into dg/den
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi authored Mar 31, 2021
2 parents 0a35664 + 28f34d1 commit 35d737b
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ to the constructor's keyword `bias=bias`.
function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
bias ? fill!(similar(weights, dims...), 0) : Zeros()
end

function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
bias
Expand Down
9 changes: 5 additions & 4 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ import Flux: activations
@test Dense(rand(100,10), false, tanh).σ == tanh
@test Dense(rand(100,10), rand(100)).σ == identity
@test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type
# @test Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match
@test_skip Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match

@test_skip Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
@test_skip Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}

# @test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
# @test Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}

@test_throws MethodError Dense(10, 10.5)
@test_throws MethodError Dense(10, 10.5, tanh)
Expand Down Expand Up @@ -167,7 +168,7 @@ import Flux: activations
@test size(b3(rand(4), rand(5))) == (3,)

b4 = Flux.Bilinear(3,3,7; bias=1:7, init=Flux.zeros)
# @test b4.bias isa Vector{Float32}
@test_skip b4.bias isa Vector{Float32}

@test_throws ArgumentError Flux.Bilinear(rand(3)) # expects a 3-array
@test_throws ArgumentError Flux.Bilinear(rand(3,4), false, tanh)
Expand Down
2 changes: 1 addition & 1 deletion test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ end
@test fun(rand(2,3,4,5), false).bias isa Flux.Zeros
if fun == Conv
@test fun(rand(2,3,4,5,6), rand(6)).bias isa Vector{Float64}
# @test fun(rand(2,3,4,5,6), 1:6).bias isa Vector{Float64}
@test_skip fun(rand(2,3,4,5,6), 1:6).bias isa Vector{Float64}
elseif fun == DepthwiseConv
@test fun(rand(2,3,4,5,6), rand(30)).bias isa Vector{Float64}
end
Expand Down
2 changes: 1 addition & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ end
testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
@test l1.W == l2.W
@test l1.b == l2.b
@test typeof(l1.b) === typeof(l2.b)
@test_skip typeof(l1.b) === typeof(l2.b)
end

@testset "loadparams!" begin
Expand Down

0 comments on commit 35d737b

Please sign in to comment.