
Commit 601e2d8

Merge pull request #586 from KristofferC/kc/batchnorm

Work around an extreme slowdown in BatchNorm due to a Julia performance bug in broadcast fusion.

2 parents fe712bf + 9914c53, commit 601e2d8

2 files changed: +8 -1 lines changed

src/layers/normalise.jl

Lines changed: 3 additions & 1 deletion
@@ -138,7 +138,9 @@ function (BN::BatchNorm)(x)
   end

   let λ = BN.λ
-    λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) .+ reshape(β, affine_shape...))
+    temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ))
+    # This is intentionally not fused because of an extreme slowdown doing so
+    λ.(temp .+ reshape(β, affine_shape...))
   end
 end
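For context on the change: in Julia, a chain of dotted operations such as λ.(a .* b .+ c) is fused into a single broadcast kernel, and the performance bug referenced in the commit message made this particular fused kernel pathologically slow. Assigning the intermediate result to temp breaks the chain into two smaller kernels. Below is a minimal sketch of the pattern with placeholder scalars and arrays; the function names are illustrative, not Flux code:

# Fully fused: the entire right-hand side compiles to one broadcast kernel.
fused(λ, γ, x, μ, σ², β, ϵ) = λ.(γ .* ((x .- μ) ./ sqrt.(σ² .+ ϵ)) .+ β)

# Workaround: materialising `temp` splits the expression into two smaller
# broadcast kernels, sidestepping the slow fused case.
function unfused(λ, γ, x, μ, σ², β, ϵ)
  temp = γ .* ((x .- μ) ./ sqrt.(σ² .+ ϵ))
  return λ.(temp .+ β)
end

# Both variants compute the same result:
x = randn(Float32, 4, 4, 2, 1)
fused(identity, 1f0, x, 0f0, 1f0, 0f0, 1f-5) ≈ unfused(identity, 1f0, x, 0f0, 1f0, 0f0, 1f-5)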

test/layers/normalisation.jl

Lines changed: 5 additions & 0 deletions
@@ -98,4 +98,9 @@ end
   y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
   @test m(x) == y
 end
+
+  let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
+    m(x)
+    @test (@allocated m(x)) < 100_000_000
+  end
 end
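The new test uses the standard warm-up idiom for allocation checks: the first m(x) call triggers compilation, whose allocations would otherwise dominate, so @allocated measures only the second, warm call against the 100_000_000-byte budget. The same idiom in isolation, with a placeholder function standing in for the layer:

using Test

f(x) = sum(x .+ 1)    # placeholder for the layer call m(x)
x = randn(Float32, 416, 416, 32, 1)

f(x)                                    # warm-up call: compiles f for this argument type
@test (@allocated f(x)) < 100_000_000   # the warm call must stay under the budget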
