@@ -69,100 +69,134 @@ extraChain(::Tuple{}, x) = ()
69
69
70
70
71
71
"""
    Dense(in, out, σ=identity; bias=true, init=glorot_uniform)
    Dense(W::AbstractMatrix, [bias, σ])

Create a traditional `Dense` layer, whose forward pass is given by:

    y = σ.(W * x .+ bias)

The input `x` should be a vector of length `in`, or batch of vectors represented
as an `in × N` matrix, or any array with `size(x,1) == in`.
The out `y` will be a vector of length `out`, or a batch with
`size(y) == (out, size(x)[2:end]...)`

Keyword `bias=false` will switch off trainable bias for the layer.
The initialisation of the weight matrix is `W = init(out, in)`, calling the function
given to keyword `init`, with default [`glorot_uniform`](@ref Flux.glorot_uniform).
The weight matrix and/or the bias vector (of length `out`) may also be provided explicitly.

# Examples
```jldoctest
julia> d = Dense(5, 2)
Dense(5, 2)

julia> d(rand(Float32, 5, 64)) |> size
(2, 64)

julia> d(rand(Float32, 5, 1, 1, 64)) |> size  # treated as three batch dimensions
(2, 1, 1, 64)

julia> d1 = Dense(ones(2, 5), false, tanh)  # using provided weight matrix
Dense(5, 2, tanh; bias=false)

julia> d1(ones(5))
2-element Array{Float64,1}:
 0.9999092042625951
 0.9999092042625951

julia> Flux.params(d1)  # no trainable bias
Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
```
"""
struct Dense{F, M<:AbstractMatrix, B}
  weight::M
  bias::B
  σ::F
  # Inner constructor: `bias` may be `true` (trainable zero vector), `false`
  # (non-trainable `Zeros()`), or an explicit vector of length `size(W,1)`;
  # `create_bias` resolves these three cases.
  function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
    b = create_bias(W, bias, size(W,1))
    new{F,M,typeof(b)}(W, b, σ)
  end
end
112
121
113
- Dense (W, b) = Dense (W, b, identity)
114
-
115
122
# Keyword constructor. `initW`/`initb` are deprecated in favour of `init`
# and an explicit `bias` vector; the deprecated paths still work but warn.
function Dense(in::Integer, out::Integer, σ = identity;
               initW = nothing, initb = nothing,
               init = glorot_uniform, bias = true)

  W = if initW !== nothing
    # Fixed typo in the user-facing deprecation message ("funtion" -> "function").
    Base.depwarn("keyword initW is deprecated, please use init (which similarly accepts a function like randn)", :Dense)
    initW(out, in)
  else
    init(out, in)
  end

  # Only honour `initb` when the user did not already disable or supply a bias.
  b = if bias === true && initb !== nothing
    Base.depwarn("keyword initb is deprecated, please simply supply the bias vector, bias=initb(out)", :Dense)
    initb(out)
  else
    bias
  end

  return Dense(W, b, σ)
end
119
142
120
143
@functor Dense

# Forward pass: collapse all trailing dimensions into one batch dimension,
# apply the affine map + activation, then restore the original batch shape
# (with `out` rows in place of `in`).
function (a::Dense)(x::AbstractArray)
  W, b, σ = a.weight, a.bias, a.σ
  dims = size(x)
  flat = reshape(x, dims[1], :)
  activated = σ.(W * flat .+ b)
  return reshape(activated, :, dims[2:end]...)
end
129
152
130
153
# Compact printing, mirroring the keyword constructor: Dense(in, out[, σ][; bias=false]).
function Base.show(io::IO, l::Dense)
  nout, nin = size(l.weight)
  print(io, "Dense(", nin, ", ", nout)
  if l.σ != identity
    print(io, ", ", l.σ)
  end
  if l.bias == Zeros()
    print(io, "; bias=false")
  end
  print(io, ")")
end
135
159
136
160
"""
    Diagonal(α, β)
    Diagonal(size::Integer...)

Create an element-wise linear layer, which performs

    y = α .* x .+ β

The learnable arrays are initialised `α = ones(Float32, size)` and
`β = zeros(Float32, size)`.

Used by [`LayerNorm`](@ref).
"""
struct Diagonal{T}
  α::T
  β::T
end

function Diagonal(sz::Integer...; initα = nothing, initβ = nothing)
  α = if initα !== nothing
    Base.depwarn("keyword initα is deprecated, please simply supply the desired vectors", :Diagonal)
    initα(sz...)
  else
    # Use Float32, matching the docstring and the pre-refactor default
    # (`i -> ones(Float32, i)`); bare `ones(sz...)` would silently create
    # Float64 parameters.
    ones(Float32, sz...)
  end
  β = if initβ !== nothing
    Base.depwarn("keyword initβ is deprecated, please simply supply the desired vectors", :Diagonal)
    initβ(sz...)
  else
    zeros(Float32, sz...)
  end
  Diagonal(α, β)
end
159
193
160
194
@functor Diagonal

# Element-wise affine transform; `x` must be broadcast-compatible with α and β.
function (a::Diagonal)(x)
  return a.α .* x .+ a.β
end
163
197
164
198
# Print the stored size as comma-separated dimensions, e.g. Diagonal(2, 3).
function Base.show(io::IO, l::Diagonal)
  dims = join(size(l.α), ", ")
  print(io, "Diagonal(", dims, ")")
end
167
201
168
202
"""
@@ -249,55 +283,71 @@ function Base.show(io::IO, b::SkipConnection)
249
283
end
250
284
251
285
"""
    Bilinear(in1, in2, out, σ=identity; bias=true, init=glorot_uniform)
    Bilinear(W::AbstractArray, [bias, σ])

Creates a Bilinear layer, which operates on two inputs at the same time.
Its output, given vectors `x` & `y`, is another vector `z` with,
for all `i ∈ 1:out`:

    z[i] = σ(x' * W[i,:,:] * y + bias[i])

If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of this form,
with `B` a Bilinear layer.

If `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)`
The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`,
which is accepted as the input to a `Chain`.

The initialisation works as for [`Dense`](@ref) layer, with `W = init(out, in1, in2)`.
By default the bias vector is `zeros(Float32, out)`, option `bias=false` will switch off
trainable bias. Either of these may be provided explicitly.

# Examples

```jldoctest
julia> x, y = randn(Float32, 5, 32), randn(Float32, 5, 32);

julia> B = Flux.Bilinear(5, 5, 7);

julia> B(x) |> size  # interactions based on one input
(7, 32)

julia> B(x,y) == B((x,y))  # two inputs, may be given as a tuple
true

julia> sc = SkipConnection(
                Chain(Dense(5, 20, tanh), Dense(20, 9, tanh)),
                Flux.Bilinear(9, 5, 3, bias=false),
              );  # used as the recombinator, with skip as the second input

julia> sc(x) |> size
(3, 32)

julia> Flux.Bilinear(rand(4,8,16), false, tanh)  # first dim of weight is the output
Bilinear(8, 16, 4, tanh, bias=false)
```
"""
struct Bilinear{F,A,B}
  weight::A
  bias::B
  σ::F
  # Inner constructor: the weight must be a 3-array, laid out (out, in1, in2).
  # `bias` may be `true`/`false`/an explicit vector, resolved by `create_bias`.
  function Bilinear(W::A, bias = true, σ::F = identity) where {A<:AbstractArray, F}
    ndims(A) == 3 || throw(ArgumentError("expected a 3-array of weights"))
    b = create_bias(W, bias, size(W,1))
    new{F,A,typeof(b)}(W, b, σ)
  end
end
289
341
290
342
@functor Bilinear

# Keyword constructor: builds the (out, in1, in2) weight via `init` and
# delegates bias handling to the inner constructor.
function Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity;
                  init = glorot_uniform, bias = true)
  return Bilinear(init(out, in1, in2), bias, σ)
end
298
348
299
349
function (a:: Bilinear )(x:: AbstractMatrix , y:: AbstractMatrix )
300
- W, b, σ = a. W , a. b , a. σ
350
+ W, b, σ = a. weight , a. bias , a. σ
301
351
302
352
d_z, d_x, d_y = size (W)
303
353
d_x == size (x,1 ) && d_y == size (y,1 ) || throw (DimensionMismatch (" number of rows in data must match W" ))
@@ -319,13 +369,14 @@ end
319
369
# Accept the two inputs as a single tuple, so a Bilinear layer composes in a Chain.
(a::Bilinear)(xy::NTuple{2, AbstractArray}) = a(first(xy), last(xy))
320
370
321
371
# Compact printing, mirroring the keyword constructor: Bilinear(in1, in2, out[, σ][, bias=false]).
function Base.show(io::IO, l::Bilinear)
  # Weight is laid out (out, in1, in2); print in constructor-argument order.
  print(io, "Bilinear(", size(l.weight, 2), ", ", size(l.weight, 3), ", ", size(l.weight, 1))
  l.σ == identity || print(io, ", ", l.σ)
  # Unqualified `Zeros()` for consistency with `show(io, ::Dense)` in this file.
  l.bias == Zeros() && print(io, ", bias=false")
  print(io, ")")
end
326
377
327
378
"""
328
- Parallel(connection, layers...)
379
+ Parallel(connection, layers...)
329
380
330
381
Create a 'Parallel' layer that passes an input array to each path in
331
382
`layers`, reducing the output with `connection`.
0 commit comments