Make BatchNorm twice-differentiable
Solves the problem of BatchNorm not being twice-differentiable, first reported at https://discourse.julialang.org/t/compilation-error-in-zygote-flux-and-cuda-interaction/92571 and tracked as issue FluxML#2154.
alexrosen45 authored May 8, 2023
1 parent 7859403 commit dd4757b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/cuda/cudnn.jl
@@ -13,9 +13,9 @@ end

 function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
   y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
-  function batchnorm_pullback(Δ)
-    grad = ∇batchnorm(g, b, x, unthunk(Δ), running_mean, running_var, momentum; kw...)
-    (NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
+  function batchnorm_pullback(Δ, σ²Δ)
+    grad, σ²grad = ∇batchnorm(g, b, x, unthunk(Δ), running_mean, running_var, momentum; kw...)
+    (NoTangent(), grad..., NoTangent(), NoTangent(), σ²grad..., NoTangent())
   end
   y, batchnorm_pullback
 end
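
For context, the pattern that failed before this change was taking a gradient of a gradient through a GPU BatchNorm layer: Zygote then differentiates through batchnorm_pullback itself, so the pullback must be differentiable. The snippet below is a minimal sketch, not part of the commit; it assumes a working CUDA device and that Flux, CUDA, and Zygote are installed.

# Minimal sketch (hypothetical usage, not part of this commit): second-order
# differentiation through BatchNorm on the GPU, the pattern from FluxML#2154.
using Flux, CUDA, Zygote

bn = BatchNorm(3) |> gpu                 # BatchNorm over 3 channels
x  = CUDA.randn(Float32, 4, 4, 3, 2)     # WHCN input: 4x4 spatial, 3 channels, batch 2

loss(x) = sum(abs2, bn(x))               # scalar first-order loss

# Norm of the first-order input gradient; differentiating this requires
# the batchnorm pullback itself to be differentiable.
gradnorm(x) = sum(abs2, only(Zygote.gradient(loss, x)))

Zygote.gradient(gradnorm, x)             # errored before this commit
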
