Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fp16 aten_native_batch_norm when bias is None and training is True #1217

Merged
9 commits merged on Jan 23, 2024
22 changes: 14 additions & 8 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5555,10 +5555,14 @@ def aten_native_batch_norm(
"""native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)"""

if weight is None: # Set to 1.0 as default
weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))
weight = op.CastLike(
op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2)), input
)
BowenBao marked this conversation as resolved.
Show resolved Hide resolved

if bias is None: # Set to 0.0 as default
bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))
bias = op.CastLike(
op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2)), input
)

axes = list(range(len(input.shape)))
axes.pop(1)
Expand Down Expand Up @@ -5609,13 +5613,14 @@ def _aten_native_batch_norm_training_onnx(
training_mode=training,
)
# Compute var and rstd
mean = op.ReduceMean(input, axes)
input_sub_mean = op.Sub(input, mean)
upcast_input = op.Cast(input, to=FLOAT.dtype)
BowenBao marked this conversation as resolved.
Show resolved Hide resolved
mean = op.ReduceMean(upcast_input, axes)
input_sub_mean = op.Sub(upcast_input, mean)
sqr = op.Mul(input_sub_mean, input_sub_mean)
var = op.ReduceMean(sqr, axes, keepdims=False)
rstd = op.Div(1.0, op.Sqrt(var + eps))
# Get mean again with size = [1, C]
mean = op.ReduceMean(input, axes, keepdims=False)
mean = op.ReduceMean(upcast_input, axes, keepdims=False)
return norm, mean, rstd

justinchuby marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -5724,13 +5729,14 @@ def _aten__native_batch_norm_training_functional_onnx(
training_mode=training,
)
# Compute var and rstd
mean = op.ReduceMean(input, axes)
input_sub_mean = op.Sub(input, mean)
upcast_input = op.Cast(input, to=FLOAT.dtype)
mean = op.ReduceMean(upcast_input, axes)
input_sub_mean = op.Sub(upcast_input, mean)
sqr = op.Mul(input_sub_mean, input_sub_mean)
var = op.ReduceMean(sqr, axes, keepdims=False)
rstd = op.Div(1.0, op.Sqrt(var + eps))
# Get mean again with size = [1, C]
mean = op.ReduceMean(input, axes, keepdims=False)
mean = op.ReduceMean(upcast_input, axes, keepdims=False)
# NOTE: Fixed to be FLOAT dtype
running_mean = op.Cast(running_mean, to=FLOAT.dtype)
running_var = op.Cast(running_var, to=FLOAT.dtype)
Expand Down
Loading