Commit b25397b (parent 41feb43)

WeightNormWeight is now called WeightNormParam

- WeightNorm for several params, single dim
- Test for scalar and vector dims
- Test that the newly created WeightNorm matches the wrapped layer
- Simplified some bits
- Added the missing last constructor

File tree: 3 files changed (+43, −30 lines)
src/Flux.jl (1 addition, 1 deletion)

@@ -11,7 +11,7 @@ export gradient

 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       WeightNorm, WeightNormWeight, SkipConnection, params, fmap, cpu, gpu, f32, f64
+       WeightNorm, WeightNormParam, SkipConnection, params, fmap, cpu, gpu, f32, f64

 include("optimise/Optimise.jl")
 using .Optimise
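Once this commit is checked out, the rename can be sanity-checked from the REPL (a minimal sketch; it only restates the export change above):

```julia
using Flux

# The new name is bound (and exported) ...
@assert isdefined(Flux, :WeightNormParam)
# ... and the old one is gone.
@assert !isdefined(Flux, :WeightNormWeight)
```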

src/layers/normalise.jl (32 additions, 25 deletions)

@@ -371,13 +371,13 @@ end
 Weight Normalization.
 This layer reparametrizes weights (w) of a layer with its decomposition into magnitude (g) and direction (v).

-    WeightNorm(layer, weight::Union{Symbol,Int}, dim)
+    WeightNorm(layer, weight, dim)

 ``layer`` is the layer being normalized.
-``weight`` is the parameter to be normalized.
+``weight`` is the parameter (or parameters) to be normalized.
-``dim`` is the dimension of normalization.
+``dim`` is the dimension (or dimensions) of normalization.
 Often, it's the dimension encoding the output channels.

 Example:
@@ -390,55 +390,62 @@ wndB = WeightNorm(d, :W, 1:2); #Now we normalize all directions together, keepin
 Link: https://arxiv.org/pdf/1602.07868.pdf
 """

-struct WeightNormWeight{T,N,I}
+struct WeightNormParam{T,N,I}
   g::AbstractArray{T,N}
   v::AbstractArray{T,N}
   dim::I
 end

-Base.size(w::WeightNormWeight, i...) = size(w.v, i...)
-Base.size(w::WeightNormWeight) = size(w.v)
-Base.iterate(w::WeightNormWeight, i...) = iterate(w.g .* w.v ./ WN_mag(w.v, w.dim), i...)
-Base.getindex(w::WeightNormWeight, i...) = getindex(w.g .* w.v ./ WN_mag(w.v, w.dim), i...)
-Base.ndims(w::WeightNormWeight) = ndims(w.v)
+Base.size(w::WeightNormParam, i...) = size(w.v, i...)
+Base.size(w::WeightNormParam) = size(w.v)
+Base.iterate(w::WeightNormParam, i...) = iterate(w.g .* w.v ./ WN_mag(w.v, w.dim), i...)
+Base.getindex(w::WeightNormParam, i...) = getindex(w.g .* w.v ./ WN_mag(w.v, w.dim), i...)
+Base.ndims(w::WeightNormParam) = ndims(w.v)
+Base.length(w::WeightNormParam) = length(w.v)

-Flux.@functor WeightNormWeight
+@functor WeightNormParam

-WN_mag(p, dim) = sqrt.(sum(abs2.(p), dims = dim))
-WN_dir(p, mag, eps) = p ./ (mag .+ eps)
-WN_dir(p, mag) = WN_dir(p, mag, eps(eltype(p)))
+WN_mag(p, dim, eps) = sqrt.(sum(abs2.(p), dims = dim)) .+ eps
+WN_mag(p, dim) = WN_mag(p, dim, eps(eltype(p)))
+WN_dir(p, mag) = p ./ mag

 import Base.*, Base./, Base.+, Base.-
 for f in (:+, :-, :*, :/)
-  @eval ($f)(z::AbstractArray, w::WeightNormWeight) = ($f)(z, w.g .* w.v ./ WN_mag(w.v, w.dim))
-  @eval ($f)(w::WeightNormWeight, z::AbstractArray) = ($f)(w.g .* w.v ./ WN_mag(w.v, w.dim), z)
+  @eval ($f)(z::AbstractArray, w::WeightNormParam) = ($f)(z, w.g .* w.v ./ WN_mag(w.v, w.dim))
+  @eval ($f)(w::WeightNormParam, z::AbstractArray) = ($f)(w.g .* w.v ./ WN_mag(w.v, w.dim), z)
 end

-struct WeightNorm{L,E,I,W}
+struct WeightNorm{L}
   layer::L
-  eps::E
-  weight::W
-  dim::I
+  eps::Number
+  weight::Vector
+  dim::Vector
 end

-Flux.@functor WeightNorm
+@functor WeightNorm

 function Base.show(io::IO, wn::WeightNorm)
   print(io, "WeightNorm(", wn.layer, ", ", wn.weight, ", ", wn.dim, ")")
 end

-function WeightNorm(layer, weight::Union{Symbol,Int}, dim)
+function WeightNorm(layer, weight::Vector, dim::Vector)
   # Expose layer fields and constructor
   func, re = Flux.functor(layer)
   # Get the fields
   par = [getfield(layer, fn) for fn in keys(func)]
-  w = getfield(layer, weight)
-  g = WN_mag(w, dim)
-  v = WN_dir(w, g)
-  par[findfirst(keys(func) .== weight)] = WeightNormWeight(g, v, dim)
+  w = map(weight) do W
+    getfield(layer, W)
+  end
+  g = map((W, D) -> WN_mag(W, D), w, dim)
+  v = map((W, G) -> WN_dir(W, G), w, g)
+  par[indexin(weight, collect(keys(func)))] = WeightNormParam.(g, v, dim)
   return WeightNorm(re(par), eps(Float32), weight, dim)
 end

+WeightNorm(layer, weight::Symbol, dim::Vector) = WeightNorm(layer, [weight], dim)
+WeightNorm(layer, weight::Symbol, dim::Integer) = WeightNorm(layer, [weight], [dim])
+WeightNorm(layer, weight::Vector, dim::Integer) = WeightNorm(layer, weight, [dim for _ in axes(weight, 1)])

 function (wn::WeightNorm)(x)
   wn.layer(x)
 end
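For reference, the helpers above implement the decomposition from the linked paper: a weight w is stored as a magnitude g = ||w|| (taken along `dim`) plus a direction v, and reassembled as g .* v ./ ||v||. A minimal standalone sketch, with `WN_mag`/`WN_dir` copied from the + lines above and the random weight matrix purely illustrative:

```julia
WN_mag(p, dim, eps) = sqrt.(sum(abs2.(p), dims = dim)) .+ eps
WN_mag(p, dim) = WN_mag(p, dim, eps(eltype(p)))
WN_dir(p, mag) = p ./ mag

w = randn(Float32, 9, 10)        # stand-in for a Dense weight matrix
g = WN_mag(w, 1)                 # 1×10 row of per-column magnitudes
v = WN_dir(w, g)                 # columns of (approximately) unit norm
w_back = g .* v ./ WN_mag(v, 1)  # the expression getindex/iterate use above
@assert w_back ≈ w               # the round trip is exact up to eps
```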

test/layers/normalisation.jl (10 additions, 4 deletions)

@@ -196,20 +196,26 @@ end
   d = Dense(10, 9, tanh)
   gs = gradient(() -> sum(abs2, d(fake_data)), params(d))
   W = d.W
-  for WN_dim in [1, 2, 1:2]
+  for WN_dim in [[1], 1, [2], 2, [1:2]]
     wnd = WeightNorm(d, :W, WN_dim)
     gswn = gradient(() -> sum(abs2, wnd(fake_data)), params(wnd))
     g = wnd.layer.W.g
     v = wnd.layer.W.v
-    normv = sum(abs2, v, dims = WN_dim)

     ΔW = gs[W]
     Δg = gswn[g]
     Δv = gswn[v]
-    @test sum(ΔW .* v ./ normv, dims = WN_dim) ≈ Δg
+    @test wnd(fake_data) ≈ d(fake_data)
+    if isa(WN_dim, Int)
+      normv = sum(abs2, v, dims = WN_dim)
+      @test sum(ΔW .* v ./ normv, dims = WN_dim) ≈ Δg
+    else
+      normv = sum(abs2, v, dims = WN_dim[1])
+      @test sum(ΔW .* v ./ normv, dims = WN_dim[1]) ≈ Δg
+    end
     @test g ./ normv .* ΔW - g .* Δg .* v ./ (normv.^2) ≈ Δv
     @test size(Δv) == size(ΔW)
-    @test isa(wnd.layer.W, WeightNormWeight)
+    @test isa(wnd.layer.W, Flux.WeightNormParam)
   end
 end
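The convenience constructors added at the end of normalise.jl are what let this test loop mix scalar and vector `dim`s. A hedged usage sketch (assumes this commit and the `:W`/`:b` field names of this Flux version's `Dense`, as used in the test above; sizes are illustrative):

```julia
using Flux

d = Dense(10, 9, tanh)

wn1 = WeightNorm(d, :W, 1)        # Symbol + Integer -> WeightNorm(d, [:W], [1])
wn2 = WeightNorm(d, :W, [1:2])    # Symbol + Vector  -> WeightNorm(d, [:W], [1:2])
wn3 = WeightNorm(d, [:W, :b], 1)  # Vector + Integer: the dim is repeated per param

x = rand(Float32, 10, 3)
@assert wn1(x) ≈ d(x)             # reparametrizing leaves the output unchanged
```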
