add op(native group norm) | feat(atenlib) #644

Merged (30 commits, Apr 21, 2023)
Commits
47447c1 update (xiaowuhu, Mar 30, 2023)
081c55c Update core.py (xiaowuhu, Mar 30, 2023)
f056fd7 Update ops_correctness_test.py (xiaowuhu, Apr 3, 2023)
5802188 update (xiaowuhu, Apr 17, 2023)
4fb749d Update core.py (xiaowuhu, Apr 17, 2023)
4f49db6 Update core.py (xiaowuhu, Apr 17, 2023)
5500bd1 Update core.py (xiaowuhu, Apr 17, 2023)
9e85674 Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 17, 2023)
77808ef Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 17, 2023)
35885a9 Merge branch 'xiaowu/addOp(native_group_norm)' of https://github.com/… (xiaowuhu, Apr 17, 2023)
264de39 update (xiaowuhu, Apr 17, 2023)
4761f0d Update core.py (xiaowuhu, Apr 17, 2023)
27dce7e Update core.py (xiaowuhu, Apr 17, 2023)
f40f0d3 update (xiaowuhu, Apr 18, 2023)
6411efe Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 18, 2023)
f90e095 update (xiaowuhu, Apr 18, 2023)
e624822 update (xiaowuhu, Apr 18, 2023)
497659a Update extra_opinfo.py (xiaowuhu, Apr 18, 2023)
c83905c update (xiaowuhu, Apr 19, 2023)
fac4c13 Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 19, 2023)
8f68565 Update core.py (xiaowuhu, Apr 19, 2023)
6b89bb1 Merge branch 'xiaowu/addOp(native_group_norm)' of https://github.com/… (xiaowuhu, Apr 19, 2023)
a8fc8f7 Update core.py (xiaowuhu, Apr 19, 2023)
e50c817 Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 19, 2023)
de9359b Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 19, 2023)
e2924f4 Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 20, 2023)
958820b update (xiaowuhu, Apr 20, 2023)
b38af29 Update core.py (xiaowuhu, Apr 20, 2023)
c3e3b67 Update core.py (xiaowuhu, Apr 20, 2023)
7005a61 Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 20, 2023)
Changes from commit c83905c767a8c9389308e0dfbafd485482b2dfbf ("update")
xiaowuhu committed Apr 19, 2023
72 changes: 44 additions & 28 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -4213,49 +4213,65 @@ def aten_native_group_norm(
input: TFloat,
weight: Optional[TFloat],
bias: Optional[TFloat],
N: INT64 = None, # pylint: disable=unused-argument
C: INT64 = None, # pylint: disable=unused-argument
HxW: INT64 = None, # pylint: disable=unused-argument
group: int = None,
eps: float = None,
N: Optional[INT64] = None, # pylint: disable=unused-argument
C: Optional[INT64] = None, # pylint: disable=unused-argument
HxW: Optional[INT64] = None, # pylint: disable=unused-argument
group: Optional[int] = None,
eps: Optional[float] = 1e-05,
) -> Tuple[TFloat, TFloat, TFloat]:
"""native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)"""

# Assert(weight is not None, and, bias is not None)
# Create weight_instance_norm and bias_instance_norm
weight_inst = op.Constant(value_floats=[1.0] * group)
bias_inst = op.Constant(value_floats=[0.0] * group)
# 0 in the shape list keeps dimension value unchanged, for InstanceNorm need
shape = op.Constant(value_ints=[0, group, -1])
# Actually we don't need N,C,HxW value because the input tensor has that information
if group is None:
group = 1 # Equal to LayerNorm

norm = _aten_native_group_norm_onnx(
input, weight, bias, weight_inst, bias_inst, shape, eps
)
# weight_inst, bias_inst are fake outputs, because we must return 3 outputs
return norm, weight_inst, bias_inst
if weight is None:  # Default to 1.0; the shape is the channel size
weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))

if bias is None:  # Default to 0.0; the shape is the channel size
bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))

norm, fake_mean, fake_var = _aten_native_group_norm_onnx(input, weight, bias, group, eps)
# FIXME: return fake values because we must return 3 outputs (norm, mean, var)
# We know how 'mean' is computed in Torch, but not how 'var' is
return norm, fake_mean, fake_var


@torch_op("aten::native_group_norm", private=True)
def _aten_native_group_norm_onnx(
input: TFloat,
weight: Optional[TFloat],
bias: Optional[TFloat],
weight_inst: TFloat,
bias_inst: TFloat,
shape: INT64,
eps: float = None,
) -> TReal:
# Using InstanceNorm to simulate GroupNorm, because GroupNorm need weight[group] and bias[group]
# But the input is weight[channel] and bias[channel]
input_reshaped = op.Reshape(input, shape)
weight: TFloat,
bias: TFloat,
group: int,
eps: float,
) -> Tuple[TFloat, TFloat, TFloat]:
# Use InstanceNorm to simulate op.GroupNorm, because op.GroupNorm needs weight[group] and bias[group],
# but the inputs are weight[channel] and bias[channel], so the sizes do not match
# Create weight_instance_norm and bias_instance_norm
shape_group = op.Reshape(op.Constant(value_int=group), op.Constant(value_ints=[-1]))
# 0 in the shape list keeps the dimension value unchanged, as InstanceNorm needs
shape_input = op.Concat(
op.Constant(value_ints=[0]),
shape_group,
op.Constant(value_ints=[-1]),
axis=0
)
input_reshaped = op.Reshape(input, shape_input)
weight_inst_norm = op.Expand(op.Constant(value_floats=[1.0]), shape_group)
bias_inst_norm = op.Expand(op.Constant(value_floats=[0.0]), shape_group)
norm_reshaped = op.InstanceNormalization(
input_reshaped, weight_inst, bias_inst, epsilon=eps
input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
)
norm = op.Reshape(norm_reshaped, op.Shape(input))
input_rank = op.Size(op.Shape(input))
axes = op.Range(1, input_rank - 1, 1)
# Use the real weight and bias to compute the result again
return op.Add(op.Mul(norm, op.Unsqueeze(weight, axes)), op.Unsqueeze(bias, axes))
# But they need to be unsqueezed to the target shape for easy broadcasting
weight_full_shape = op.Unsqueeze(weight, axes)
bias_full_shape = op.Unsqueeze(bias, axes)
norm_mul_weight = op.Mul(norm, weight_full_shape)
norm_result = op.Add(norm_mul_weight, bias_full_shape)
return norm_result, weight_inst_norm, bias_inst_norm


def aten_native_group_norm_backward(
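For readers unfamiliar with the trick used above, here is a minimal standalone sketch (not part of this PR; shapes and tolerances are illustrative assumptions) of why instance normalization over an (N, group, -1) reshape reproduces group normalization, and of the per-group mean/rstd that the placeholder second and third outputs stand in for:

```python
import torch
import torch.nn.functional as F

N, C, H, W, group, eps = 2, 6, 4, 4, 3, 1e-5
x = torch.randn(N, C, H, W)
weight = torch.randn(C)  # per-channel scale, like the optional `weight` argument
bias = torch.randn(C)    # per-channel shift, like the optional `bias` argument

# Reference result from PyTorch itself.
expected = F.group_norm(x, group, weight, bias, eps)

# The approach used in the diff: instance-normalize the (N, group, -1) view,
# reshape back, then apply the real per-channel weight/bias, which broadcast
# over the spatial dimensions once unsqueezed.
norm = F.instance_norm(x.reshape(N, group, -1), eps=eps).reshape(N, C, H, W)
actual = norm * weight[:, None, None] + bias[:, None, None]
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)

# The real ATen op also returns per-group mean and rstd, which is what the
# placeholder second and third outputs stand in for.
out, mean, rstd = torch.ops.aten.native_group_norm(
    x, weight, bias, N, C, H * W, group, eps
)
torch.testing.assert_close(out, expected, rtol=1e-4, atol=1e-5)
print(mean.shape, rstd.shape)  # both torch.Size([2, 3]), i.e. (N, group)
```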
@@ -1011,11 +1011,6 @@ def _where_input_wrangler(
or (len(sample.args) > 0 and not isinstance(sample.args[0], int)),
reason="this ATen overload only support one tensor as input and another int as args",
),
skip(
"native_group_norm",
matcher=lambda sample: len(sample.input.shape) == 2,
reason="ONNX only support input shape >= 3",
),
skip(
"new_ones",
matcher=lambda sample: sample.kwargs.get("dtype") is not None,
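As context for the removed skip, a rough sketch (not from the PR; sizes are arbitrary assumptions) of why rank-2 inputs no longer need to be excluded: the (0, group, -1) style reshape always yields a rank-3 tensor, which is enough for instance normalization to apply.

```python
import torch
import torch.nn.functional as F

N, C, group, eps = 4, 6, 3, 1e-5
x2d = torch.randn(N, C)  # rank-2 input, previously skipped in the tests

# PyTorch reference (group_norm accepts (N, C) inputs directly).
expected = F.group_norm(x2d, group, eps=eps)

# After the reshape the tensor is rank 3, so the InstanceNorm path works.
x3d = x2d.reshape(N, group, -1)  # shape (N, group, C // group)
actual = F.instance_norm(x3d, eps=eps).reshape(N, C)
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)
```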