Skip to content

Commit

Permalink
split create_bias into two
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Feb 13, 2021
1 parent 3954f04 commit 644ec8c
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,16 +297,20 @@ Return a bias parameter for a layer, based on the value given
to the constructor's keyword `bias=bias`.
* `bias == true` creates a zero vector, of the same type as weights.
* `bias == false` returns `Zeros()`, a special struct which exists to encode the absence of bias.
* `bias == false` returns `Zeros()`, a special struct which exists only to encode the absence of bias.
* `bias::AbstractArray` uses the array provided, provided it has the correct size and eltype. If the type is wrong, it will be converted.
"""
function create_bias(weights::AbstractArray{T}, bias::Union{Bool, AbstractArray}, dims::Integer...) where {T}
bias===true && return fill!(similar(weights, dims...), 0)
bias===false && return Zeros()
size(bias) == dims || throw(DimensionMismatch("expected bias of size $dims, but got $(size(bias))"))
eltype(bias) == T && return bias
@warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims)
return T.(bias)
function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
bias ? fill!(similar(weights, dims...), 0) : Zeros()
end
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
if eltype(bias) == eltype(weights)
return bias
else
@warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims)
return broadcast(eltype(weights), bias)
end
end

"""
Expand Down

0 comments on commit 644ec8c

Please sign in to comment.