From dd4757bd9ab19f24692d44b4b7d9bd74ee238356 Mon Sep 17 00:00:00 2001
From: Alex Rosen <64982042+alexrosen45@users.noreply.github.com>
Date: Mon, 8 May 2023 14:41:19 -0400
Subject: [PATCH] Make BatchNorm twice-differentiable

Fixes BatchNorm not being twice-differentiable on the GPU, first reported at
https://discourse.julialang.org/t/compilation-error-in-zygote-flux-and-cuda-interaction/92571
and tracked as issue #2154.
---
 src/cuda/cudnn.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 24226ab4b1..e442ef665d 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -13,9 +13,9 @@ end
 
 function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
   y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
-  function batchnorm_pullback(Δ)
-    grad = ∇batchnorm(g, b, x, unthunk(Δ), running_mean, running_var, momentum; kw...)
-    (NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
+  function batchnorm_pullback(Δ, σ²Δ)
+    grad, σ²grad = ∇batchnorm(g, b, x, unthunk(Δ), running_mean, running_var, momentum; kw...)
+    (NoTangent(), grad..., NoTangent(), NoTangent(), σ²grad..., NoTangent())
   end
   y, batchnorm_pullback
 end
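
For context, below is a minimal sketch (not part of the patch) of the second-order
pattern this change is meant to support. The layer size, input shape, and penalty
function are illustrative assumptions, not taken from the linked reproducer:

    using Flux, CUDA

    m = BatchNorm(3) |> gpu          # BatchNorm on a CuArray input dispatches to the cuDNN path
    x = CUDA.rand(Float32, 3, 16)    # hypothetical input: 3 features, batch of 16

    # Gradient-penalty style double backward: differentiate a function of the
    # input gradient. Nesting gradients like this requires the batchnorm rrule
    # pullback itself to be differentiable.
    penalty(x) = sum(abs2, gradient(x -> sum(m(x)), x)[1])
    gradient(penalty, x)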