@@ -371,13 +371,13 @@ end
371
371
Weight Normalization.
372
372
This layer reparametrizes the weights (w) of a layer by decomposing them into a magnitude (g) and a direction (v).
373
373
374
- WeightNorm(layer, weight::Union{Symbol,Int} , dim)
374
+ WeightNorm(layer, weight, dim)
375
375
376
376
``layer`` is the layer being normalized.
377
377
378
- ``weight`` is the parameter to be normalized.
378
+ ``weight`` are the parameters to be normalized.
379
379
380
- ``dim`` is the dimension of normalization.
380
+ ``dim`` are the dimensions of normalization.
381
381
Often, it's the dimension encoding the output channels.
382
382
383
383
Example:
@@ -390,55 +390,62 @@ wndB = WeightNorm(d, :W, 1:2); #Now we normalize all directions together, keepin
390
390
Link : https://arxiv.org/pdf/1602.07868.pdf
391
391
"""
392
392
393
- struct WeightNormWeight {T,N,I}
393
+ struct WeightNormParam {T,N,I}
394
394
g:: AbstractArray{T,N}
395
395
v:: AbstractArray{T,N}
396
396
dim:: I
397
397
end
398
398
399
- Base. size (w:: WeightNormWeight , i... ) = size (w. v, i... )
400
- Base. size (w:: WeightNormWeight ) = size (w. v)
401
- Base. iterate (w:: WeightNormWeight , i... ) = iterate (w. g .* w. v ./ WN_mag (w. v, w. dim), i... )
402
- Base. getindex (w:: WeightNormWeight , i... ) = getindex (w. g .* w. v ./ WN_mag (w. v, w. dim), i... )
403
- Base. ndims (w:: WeightNormWeight ) = ndims (w. v)
399
+ Base. size (w:: WeightNormParam , i... ) = size (w. v, i... )
400
+ Base. size (w:: WeightNormParam ) = size (w. v)
401
+ Base. iterate (w:: WeightNormParam , i... ) = iterate (w. g .* w. v ./ WN_mag (w. v, w. dim), i... )
402
+ Base. getindex (w:: WeightNormParam , i... ) = getindex (w. g .* w. v ./ WN_mag (w. v, w. dim), i... )
403
+ Base. ndims (w:: WeightNormParam ) = ndims (w. v)
404
+ Base. length (w:: WeightNormParam ) = length (w. v)
404
405
405
- Flux . @functor WeightNormWeight
406
+ @functor WeightNormParam
406
407
407
- WN_mag (p, dim) = sqrt .(sum (abs2 .(p), dims = dim))
408
- WN_dir (p, mag, eps ) = p ./ (mag .+ eps)
409
- WN_dir (p, mag) = WN_dir (p, mag, eps ( eltype (p)))
408
+ WN_mag (p, dim, eps ) = sqrt .(sum (abs2 .(p), dims = dim)) .+ eps
409
+ WN_mag (p, dim ) = WN_mag (p, dim, eps ( eltype (p)) )
410
+ WN_dir (p, mag) = p ./ mag
410
411
411
412
import Base.* , Base./ , Base.+ , Base.-
412
413
for f in (:+ , :- , :* , :/ )
413
- @eval ($ f)(z:: AbstractArray , w:: WeightNormWeight ) = ($ f)(z, w. g .* w. v ./ WN_mag (w. v, w. dim))
414
- @eval ($ f)(w:: WeightNormWeight , z:: AbstractArray ) = ($ f)(w. g .* w. v ./ WN_mag (w. v, w. dim), z)
414
+ @eval ($ f)(z:: AbstractArray , w:: WeightNormParam ) = ($ f)(z, w. g .* w. v ./ WN_mag (w. v, w. dim))
415
+ @eval ($ f)(w:: WeightNormParam , z:: AbstractArray ) = ($ f)(w. g .* w. v ./ WN_mag (w. v, w. dim), z)
415
416
end
416
417
417
- struct WeightNorm{L,E,I,W }
418
+ struct WeightNorm{L}
418
419
layer:: L
419
- eps:: E
420
- weight:: W
421
- dim:: I
420
+ eps:: Number
421
+ weight:: Vector
422
+ dim:: Vector
422
423
end
423
424
424
- Flux . @functor WeightNorm
425
+ @functor WeightNorm
425
426
426
427
function Base. show (io:: IO , wn:: WeightNorm )
427
428
print (io, " WeightNorm(" , wn. layer, " , " , wn. weight, " , " , wn. dim, " )" )
428
429
end
429
430
430
- function WeightNorm (layer, weight:: Union{Symbol,Int} , dim)
431
+ function WeightNorm (layer, weight:: Vector , dim:: Vector )
431
432
# Expose layer fields and constructor
432
433
func, re = Flux. functor (layer)
433
434
# Get the fields
434
435
par = [getfield (layer, fn) for fn in keys (func)]
435
- w = getfield (layer, weight)
436
- g = WN_mag (w, dim)
437
- v = WN_dir (w, g)
438
- par[findfirst (keys (func) .== weight)] = WeightNormWeight (g, v, dim)
436
+ w = map (weight) do W
437
+ getfield (layer, W)
438
+ end
439
+ g = map ((W, D) -> WN_mag (W, D), w, dim)
440
+ v = map ((W, G) -> WN_dir (W, G), w, g)
441
+ par[indexin (weight,collect (keys (func)))] = WeightNormParam .(g, v, dim)
439
442
return WeightNorm (re (par), eps (Float32), weight, dim)
440
443
end
441
444
445
+ WeightNorm (layer, weight:: Symbol , dim:: Vector ) = WeightNorm (layer, [weight], dim)
446
+ WeightNorm (layer, weight:: Symbol , dim:: Integer ) = WeightNorm (layer, [weight], [dim])
447
+ WeightNorm (layer, weight:: Vector , dim:: Integer ) = WeightNorm (layer, weight, [dim for _ in axes (weight,1 )])
448
+
442
449
function (wn:: WeightNorm )(x)
443
450
wn. layer (x)
444
451
end
0 commit comments