Skip to content

Commit

Permalink
changes based on PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dhpollack committed Feb 27, 2019
1 parent 577dbf8 commit 1004faa
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ m = Chain(
```
"""
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)

mutable struct InstanceNorm{F,V,W,N}
λ::F # activation function
β::V # bias
Expand Down Expand Up @@ -233,8 +234,8 @@ function (IN::InstanceNorm)(x)

# update moving mean/std
mtm = data(convert(T, IN.momentum))
IN.μ = reshape(mean(repeat((1 - mtm) .* IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :)
IN.σ² = reshape(mean((repeat((1 - mtm) .* IN.σ², outer=[1, bs]) .+ reshape(data(σ²), (c, bs)) .* (mtm * m / (m - 1))), dims = 2), :)
IN.μ = dropdims(mean(repeat((1 - mtm) .* IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
IN.σ² = dropdims(mean((repeat((1 - mtm) .* IN.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
end

let λ = IN.λ
Expand Down
17 changes: 16 additions & 1 deletion test/layers/normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,22 @@ end
@test m(x) == y
end

let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
# check that μ, σ², and the output are the correct size for higher rank tensors
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = m(x)
@test size(m.μ) == (sizes[end - 1], )
@test size(m.σ²) == (sizes[end - 1], )
@test size(y) == sizes
end

# show that instance norm is equal to batch norm when channel and batch dims are squashed
let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
end

let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
m(x)
@test (@allocated m(x)) < 100_000_000
end
Expand Down

0 comments on commit 1004faa

Please sign in to comment.