Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minimal fix of #1556, remove eltype checks #1558

Merged
merged 2 commits into the base branch from the source branch
Mar 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,14 +388,10 @@ to the constructor's keyword `bias=bias`.
function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
bias ? fill!(similar(weights, dims...), 0) : Zeros()
end

function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
if eltype(bias) == eltype(weights)
return bias
else
@warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims)
return broadcast(eltype(weights), bias)
end
bias
end

"""
Expand Down
6 changes: 3 additions & 3 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ import Flux: activations
@test Dense(rand(100,10), false, tanh).σ == tanh
@test Dense(rand(100,10), rand(100)).σ == identity
@test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type
@test Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match
@test_skip Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match

@test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
@test Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}
@test_skip Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}

@test_throws MethodError Dense(10, 10.5)
@test_throws MethodError Dense(10, 10.5, tanh)
Expand Down Expand Up @@ -167,7 +167,7 @@ import Flux: activations
@test size(b3(rand(4), rand(5))) == (3,)

b4 = Flux.Bilinear(3,3,7; bias=1:7, init=Flux.zeros)
@test b4.bias isa Vector{Float32}
@test_skip b4.bias isa Vector{Float32}

@test_throws ArgumentError Flux.Bilinear(rand(3)) # expects a 3-array
@test_throws ArgumentError Flux.Bilinear(rand(3,4), false, tanh)
Expand Down
2 changes: 1 addition & 1 deletion test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ end
@test fun(rand(2,3,4,5), false).bias isa Flux.Zeros
if fun == Conv
@test fun(rand(2,3,4,5,6), rand(6)).bias isa Vector{Float64}
@test fun(rand(2,3,4,5,6), 1:6).bias isa Vector{Float64}
@test_skip fun(rand(2,3,4,5,6), 1:6).bias isa Vector{Float64}
elseif fun == DepthwiseConv
@test fun(rand(2,3,4,5,6), rand(30)).bias isa Vector{Float64}
end
Expand Down
2 changes: 1 addition & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ end
testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
@test l1.W == l2.W
@test l1.b == l2.b
@test typeof(l1.b) === typeof(l2.b)
@test_skip typeof(l1.b) === typeof(l2.b)
end

@testset "loadparams!" begin
Expand Down