Commit 69e2198

bors[bot] and Michael Abbott authored
Merge #1440

1440: `Dense` keyword handling, and docstring r=DhairyaLGandhi a=mcabbott

Closes #1422, by killing the `initW` keyword, in favour of `init` as used by the Conv layers. Also fixes "in×out weight matrix", which was incorrect. And makes `Dense(rand(2,3), bias)` work like `Dense(3,2; bias)`, which again is like the Conv layers.

Edit -- also closes #1421 now, ensuring that the bias vectors of both Conv and Dense layers match the eltype of the weights.

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: Michael Abbott <me@escbook>
2 parents 95ac3b1 + ae879cc commit 69e2198
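To make the headline change concrete, here is a small illustrative sketch (not part of the commit) of the constructor behaviour described in the PR text above, assuming Flux at this revision; the `Dense(rand(2,3), bias)` form and the bias-eltype claim come straight from that description.

```julia
using Flux

# Weight matrix given positionally, like the Conv layers: an out×in matrix.
d = Dense(rand(2, 3), true, tanh)   # behaves like Dense(3, 2, tanh; bias=true)
size(d.weight)                      # (2, 3)

# Per the note about #1421, the bias should match the weight eltype.
d32 = Dense(rand(Float32, 2, 3))
eltype(d32.bias)                    # expected Float32
```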

File tree: 11 files changed (+263, -128 lines)


NEWS.md

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 * Added [Focal Loss function](https://github.com/FluxML/Flux.jl/pull/1489) to Losses module
 * The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405).
 * Dense and Conv layers no longer perform [implicit type conversion](https://github.com/FluxML/Flux.jl/pull/1394).
+* The keyword `initW` of Dense layers is now `init`, to agree with convolutional layers.
 * Excise datasets in favour of other providers in the julia ecosystem.
 * Added option to set `bias` to [false](https://github.com/FluxML/Flux.jl/pull/1379) to eliminating `bias` from being trained.
 * Add [CTC loss function](https://github.com/FluxML/Flux.jl/pull/1287) to Losses module
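As a hedged illustration of this NEWS entry (assuming the deprecation path added in src/layers/basic.jl further below), the old keyword should still be accepted but emit a warning:

```julia
using Flux

Dense(3, 2; init = Flux.glorot_uniform)   # new spelling, matching the Conv layers
Dense(3, 2; initW = randn)                # deprecated: warns, then calls randn(2, 3)
```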

docs/src/models/layers.md

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,6 @@ These core layers form the foundation of almost all neural networks.
 ```@docs
 Chain
 Dense
-Flux.Diagonal
 ```
 
 ## Convolution and Pooling Layers
@@ -57,7 +56,8 @@ But in contrast to the layers described in the other sections are not readily gr
 Maxout
 SkipConnection
 Parallel
-Bilinear
+Flux.Bilinear
+Flux.Diagonal
 ```
 
 ## Normalisation & Regularisation

src/deprecations.jl

Lines changed: 11 additions & 0 deletions
@@ -7,3 +7,14 @@
 @deprecate Conv(; weight, bias, activation=identity, kws...) Conv(weight, bias, activation; kws...)
 @deprecate ConvTranspose(; weight, bias, activation=identity, kws...) ConvTranspose(weight, bias, activation; kws...)
 @deprecate DepthwiseConv(; weight, bias, activation=identity, kws...) DepthwiseConv(weight, bias, activation; kws...)
+
+function Base.getproperty(a::Dense, s::Symbol)
+  if s === :W
+    Base.depwarn("field name dense.W is deprecated in favour of dense.weight", :Dense)
+    return getfield(a, :weight)
+  elseif s === :b
+    Base.depwarn("field name dense.b is deprecated in favour of dense.bias", :Dense)
+    return getfield(a, :bias)
+  end
+  return getfield(a, s)
+end
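A short sketch of how the accessor shim above behaves: the old field names forward to the new ones after a deprecation warning, so existing code keeps working.

```julia
using Flux

d = Dense(3, 2)

d.W === d.weight   # true, after a warning about `dense.W`
d.b === d.bias     # true, after a warning about `dense.b`
d.σ                # any other field falls through to getfield unchanged
```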

src/layers/basic.jl

Lines changed: 128 additions & 77 deletions
@@ -69,100 +69,134 @@ extraChain(::Tuple{}, x) = ()
 
 
 """
-    Dense(in, out, σ=identity; initW=glorot_uniform, initb=zeros, bias=true)
-    Dense(W, b, σ=identity)
+    Dense(in, out, σ=identity; bias=true, init=glorot_uniform)
+    Dense(W::AbstractMatrix, [bias, σ])
 
-Create a traditional `Dense` layer with in×out weight matrix `W` and
-bias vector `b` of length `out`. The forward pass is given by:
+Create a traditional `Dense` layer, whose forward pass is given by:
 
-    y = σ.(W * x .+ b)
+    y = σ.(W * x .+ bias)
 
-The input `x` must be a vector of length `in`, a batch of vectors represented
-as an `in × N` matrix, or a higher order tensor where all dimensions
-after the first one will be treated as batch dimensions.
+The input `x` should be a vector of length `in`, or a batch of vectors represented
+as an `in × N` matrix, or any array with `size(x,1) == in`.
+The output `y` will be a vector of length `out`, or a batch with
+`size(y) == (out, size(x)[2:end]...)`.
 
-The out `y` will be a vector of length `out` or a batch whose first
-dimension is `out` and the remaining dimensions are the same as in the input.
-
-Setting `bias` to `false` will switch the bias off for the layer.
-
-`initW` and `initb` are callables used to initialize weights and biases respectively,
-through the calls `initW(out, in)` and `initb(out)`.
+Keyword `bias=false` will switch off trainable bias for the layer.
+The initialisation of the weight matrix is `W = init(out, in)`, calling the function
+given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform).
+The weight matrix and/or the bias vector (of length `out`) may also be provided explicitly.
 
 # Examples
-
-```julia-repl
+```jldoctest
 julia> d = Dense(5, 2)
 Dense(5, 2)
 
-julia> d(rand(Float32, 5))
-2-element Array{Float32,1}:
- -0.16210233
-  0.123119034
+julia> d(rand(Float32, 5, 64)) |> size
+(2, 64)
 
-julia> d = Dense(5, 2; bias=false)
-Dense(5, 2)
+julia> d(rand(Float32, 5, 1, 1, 64)) |> size  # treated as three batch dimensions
+(2, 1, 1, 64)
+
+julia> d1 = Dense(ones(2, 5), false, tanh)  # using provided weight matrix
+Dense(5, 2, tanh; bias=false)
+
+julia> d1(ones(5))
+2-element Array{Float64,1}:
+ 0.9999092042625951
+ 0.9999092042625951
+
+julia> Flux.params(d1)  # no trainable bias
+Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
 ```
 """
-struct Dense{F,S<:AbstractArray,T<:Union{Zeros, AbstractVector}}
-  W::S
-  b::T
+struct Dense{F, M<:AbstractMatrix, B}
+  weight::M
+  bias::B
   σ::F
+  function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
+    b = create_bias(W, bias, size(W,1))
+    new{F,M,typeof(b)}(W, b, σ)
+  end
 end
 
-Dense(W, b) = Dense(W, b, identity)
-
 function Dense(in::Integer, out::Integer, σ = identity;
-               initW = glorot_uniform, initb = zeros, bias=true)
-  return Dense(initW(out, in), create_bias(bias, initb, out), σ)
+               initW = nothing, initb = nothing,
+               init = glorot_uniform, bias=true)
+
+  W = if initW !== nothing
+    Base.depwarn("keyword initW is deprecated, please use init (which similarly accepts a function like randn)", :Dense)
+    initW(out, in)
+  else
+    init(out, in)
+  end
+
+  b = if bias === true && initb !== nothing
+    Base.depwarn("keyword initb is deprecated, please simply supply the bias vector, bias=initb(out)", :Dense)
+    initb(out)
+  else
+    bias
+  end
+
+  return Dense(W, b, σ)
 end
 
 @functor Dense
 
 function (a::Dense)(x::AbstractArray)
-  W, b, σ = a.W, a.b, a.σ
+  W, b, σ = a.weight, a.bias, a.σ
   sz = size(x)
-  x = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions
-  x = σ.(W*x .+ b)
-  return reshape(x, :, sz[2:end]...)
+  y = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions
+  z = σ.(W*y .+ b)
+  return reshape(z, :, sz[2:end]...)
 end
 
 function Base.show(io::IO, l::Dense)
-  print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
+  print(io, "Dense(", size(l.weight, 2), ", ", size(l.weight, 1))
   l.σ == identity || print(io, ", ", l.σ)
+  l.bias == Zeros() && print(io, "; bias=false")
   print(io, ")")
 end
 
 """
     Diagonal(α, β)
-    Diagonal(sz::Integer...; initα=ones, initβ=zeros)
+    Diagonal(size::Integer...)
 
-Create an element-wise linear layer with learnable
-arrays `α` and `β` of size `sz`. The layer performs
+Create an element-wise linear layer, which performs
 
     y = α .* x .+ β
 
-The input `x` must have size broadcast-compatible with `α` and `β`.
-The parameters will be created with the calls
-`α = initα(sz)` and `β = initβ(sz)`.
+The learnable arrays are initialised `α = ones(Float32, size)` and
+`β = zeros(Float32, size)`.
+
+Used by [`LayerNorm`](@ref).
 """
 struct Diagonal{T}
   α::T
   β::T
 end
 
-function Diagonal(sz::Integer...;
-                  initα = i -> ones(Float32, i),
-                  initβ = i -> zeros(Float32, i))
-  Diagonal(initα(sz), initβ(sz))
+function Diagonal(sz::Integer...; initα = nothing, initβ = nothing)
+  α = if initα !== nothing
+    Base.depwarn("keyword initα is deprecated, please simply supply the desired vectors", :Diagonal)
+    initα(sz...)
+  else
+    ones(sz...)
+  end
+  β = if initβ !== nothing
+    Base.depwarn("keyword initβ is deprecated, please simply supply the desired vectors", :Diagonal)
+    initβ(sz...)
+  else
+    zeros(sz...)
+  end
+  Diagonal(α, β)
 end
 
 @functor Diagonal
 
 (a::Diagonal)(x) = a.α .* x .+ a.β
 
 function Base.show(io::IO, l::Diagonal)
-  print(io, "Diagonal(", size(l.α), ")")
+  print(io, "Diagonal(", join(size(l.α), ", "), ")")
 end
 
 """
@@ -249,55 +283,71 @@ function Base.show(io::IO, b::SkipConnection)
 end
 
 """
-    Bilinear(in1, in2, out)
+    Bilinear(in1, in2, out, σ=identity; bias=true, init=glorot_uniform)
+    Bilinear(W::AbstractArray, [bias, σ])
 
 Creates a Bilinear layer, which operates on two inputs at the same time.
-It has parameters `W` and `b`, and its output given vectors `x`, `y` is of the form
+Its output, given vectors `x` & `y`, is another vector `z` with,
+for all `i ∈ 1:out`:
 
-    z[i] = σ.(x' * W[i,:,:] * y .+ b[i])
+    z[i] = σ(x' * W[i,:,:] * y + bias[i])
 
 If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of this form,
-given that `B` is a Bilinear layer of appropriate size.
+with `B` a Bilinear layer.
 
 If `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)`
 The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`,
 which is accepted as the input to a `Chain`.
 
-```julia
-# using Bilinear to generate interactions, on one input
-x = randn(Float32, 11, 7)
-B = Bilinear(11, 11, 3)
-size(B(x)) == (3, 7)
-
-# using Bilinear on two data streams at once, as a tuple
-x = randn(Float32, 10, 9)
-y = randn(Float32, 2, 9)
-m = Chain(Bilinear(10, 2, 3), Dense(3, 1))
-size(m((x, y))) == (1, 9)
-
-# using Bilinear as the recombinator in a SkipConnection
-x = randn(Float32, 10, 9)
-sc = SkipConnection(Dense(10, 10), Bilinear(10, 10, 5))
-size(sc(x)) == (5, 9)
+The initialisation works as for the [`Dense`](@ref) layer, with `W = init(out, in1, in2)`.
+By default the bias vector is `zeros(Float32, out)`; the option `bias=false` will switch off
+trainable bias. Either of these may be provided explicitly.
+
+# Examples
+
+```jldoctest
+julia> x, y = randn(Float32, 5, 32), randn(Float32, 5, 32);
+
+julia> B = Flux.Bilinear(5, 5, 7);
+
+julia> B(x) |> size  # interactions based on one input
+(7, 32)
+
+julia> B(x,y) == B((x,y))  # two inputs, may be given as a tuple
+true
+
+julia> sc = SkipConnection(
+                Chain(Dense(5, 20, tanh), Dense(20, 9, tanh)),
+                Flux.Bilinear(9, 5, 3, bias=false),
+              );  # used as the recombinator, with skip as the second input
+
+julia> sc(x) |> size
+(3, 32)
+
+julia> Flux.Bilinear(rand(4,8,16), false, tanh)  # first dim of weight is the output
+Bilinear(8, 16, 4, tanh, bias=false)
 ```
 """
-struct Bilinear{A,B,S}
-  W::A
-  b::B
-  σ::S
+struct Bilinear{F,A,B}
+  weight::A
+  bias::B
+  σ::F
+  function Bilinear(W::A, bias = true, σ::F = identity) where {A<:AbstractArray, F}
+    ndims(A) == 3 || throw(ArgumentError("expected a 3-array of weights"))
+    b = create_bias(W, bias, size(W,1))
+    new{F,A,typeof(b)}(W, b, σ)
+  end
 end
 
 @functor Bilinear
 
-Bilinear(W, b) = Bilinear(W, b, identity)
-
 function Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity;
-                  initW = glorot_uniform, initb = zeros)
-  return Bilinear(initW(out, in1, in2), initb(out), σ)
+                  init = glorot_uniform, bias = true)
+  Bilinear(init(out, in1, in2), bias, σ)
 end
 
 function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
-  W, b, σ = a.W, a.b, a.σ
+  W, b, σ = a.weight, a.bias, a.σ
 
   d_z, d_x, d_y = size(W)
   d_x == size(x,1) && d_y == size(y,1) || throw(DimensionMismatch("number of rows in data must match W"))
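A hedged check of the formula quoted in the new docstring, `z[i] = σ(x' * W[i,:,:] * y + bias[i])`, against the matrix method shown at the end of this hunk; the values are illustrative and this is not a doctest from the commit.

```julia
using Flux

B = Flux.Bilinear(rand(Float32, 3, 5, 4), true, tanh)   # out=3, in1=5, in2=4
x, y = randn(Float32, 5, 7), randn(Float32, 4, 7)        # a batch of 7 columns

Z = B(x, y)                                               # expected size (3, 7)
manual = [tanh(x[:, k]' * B.weight[i, :, :] * y[:, k] + B.bias[i]) for i in 1:3, k in 1:7]
Z ≈ manual                                                # expected true, up to rounding
```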
@@ -319,13 +369,14 @@ end
 (a::Bilinear)(x::NTuple{2, AbstractArray}) = a(x[1], x[2])
 
 function Base.show(io::IO, l::Bilinear)
-  print(io, "Bilinear(", size(l.W, 2), ", ", size(l.W, 3), ", ", size(l.W, 1))
+  print(io, "Bilinear(", size(l.weight, 2), ", ", size(l.weight, 3), ", ", size(l.weight, 1))
   l.σ == identity || print(io, ", ", l.σ)
+  l.bias == Flux.Zeros() && print(io, ", bias=false")
   print(io, ")")
 end
 
 """
-    Parallel(connection, layers...)
+    Parallel(connection, layers...)
 
 Create a 'Parallel' layer that passes an input array to each path in
 `layers`, reducing the output with `connection`.