Allow ForwardDiff in BatchNorm's track_stats #2127
@@ -275,8 +275,8 @@ function _track_stats!(
   μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
   σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))

-  bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
-  bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
+  bn.μ .= value.(res_mtm .* bn.μ .+ mtm .* μnew)
Also updates bn.μ in-place, rather than replacing the field. We know it's mutable as the constructor made it so.
Can we leave a comment explaining why value is there?
Good point, done.
2a06bcb to 2dca7ec
Bump?
Fixes... well, fixes the error seen in #2122 in the narrowest way, by explicitly removing duals. Does not cause ForwardDiff to put the layer into training mode.
Has a test, which ought to survive even if some more elaborate solution is found.
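For illustration only, here is a minimal sketch of what "explicitly removing duals" amounts to, assuming ForwardDiff is installed. The names μ, μnew, mtm and res_mtm are stand-ins for the layer's field and locals; this is not the actual Flux code:

```julia
using ForwardDiff: Dual, value

# Hypothetical running-mean buffer, standing in for bn.μ
μ = zeros(Float32, 2)
mtm = 0.1f0            # momentum (assumed value)
res_mtm = 1 - mtm

# Batch means as they might look inside a ForwardDiff call: Dual numbers
μnew = [Dual(1.0f0, 1.0f0), Dual(2.0f0, 0.0f0)]

# The right-hand side is an array of Duals, which cannot be stored in a
# Float32 buffer. value. strips the partials, so the in-place update works
# and the running statistics stay plain Float32.
μ .= value.(res_mtm .* μ .+ mtm .* μnew)
```

Because the update uses .= the existing array is overwritten rather than the field being replaced, so its element type cannot change.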