add op(native group norm) | feat(atenlib) #644

Merged Apr 21, 2023 (30 commits)

Changes from 13 commits

Commits:
47447c1 update (xiaowuhu, Mar 30, 2023)
081c55c Update core.py (xiaowuhu, Mar 30, 2023)
f056fd7 Update ops_correctness_test.py (xiaowuhu, Apr 3, 2023)
5802188 update (xiaowuhu, Apr 17, 2023)
4fb749d Update core.py (xiaowuhu, Apr 17, 2023)
4f49db6 Update core.py (xiaowuhu, Apr 17, 2023)
5500bd1 Update core.py (xiaowuhu, Apr 17, 2023)
9e85674 Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 17, 2023)
77808ef Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 17, 2023)
35885a9 Merge branch 'xiaowu/addOp(native_group_norm)' of https://github.com/… (xiaowuhu, Apr 17, 2023)
264de39 update (xiaowuhu, Apr 17, 2023)
4761f0d Update core.py (xiaowuhu, Apr 17, 2023)
27dce7e Update core.py (xiaowuhu, Apr 17, 2023)
f40f0d3 update (xiaowuhu, Apr 18, 2023)
6411efe Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 18, 2023)
f90e095 update (xiaowuhu, Apr 18, 2023)
e624822 update (xiaowuhu, Apr 18, 2023)
497659a Update extra_opinfo.py (xiaowuhu, Apr 18, 2023)
c83905c update (xiaowuhu, Apr 19, 2023)
fac4c13 Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 19, 2023)
8f68565 Update core.py (xiaowuhu, Apr 19, 2023)
6b89bb1 Merge branch 'xiaowu/addOp(native_group_norm)' of https://github.com/… (xiaowuhu, Apr 19, 2023)
a8fc8f7 Update core.py (xiaowuhu, Apr 19, 2023)
e50c817 Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 19, 2023)
de9359b Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 19, 2023)
e2924f4 Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 20, 2023)
958820b update (xiaowuhu, Apr 20, 2023)
b38af29 Update core.py (xiaowuhu, Apr 20, 2023)
c3e3b67 Update core.py (xiaowuhu, Apr 20, 2023)
7005a61 Merge branch 'main' into xiaowu/addOp(native_group_norm) (xiaowuhu, Apr 20, 2023)
51 changes: 41 additions & 10 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -4208,19 +4208,50 @@ def aten_native_dropout_backward(
     raise NotImplementedError()


+@torch_op("aten::native_group_norm", trace_only=True)
 def aten_native_group_norm(
-    input: TensorType,
-    weight: Optional[TensorType],
-    bias: Optional[TensorType],
-    N: INT64,
-    C: INT64,
-    HxW: INT64,
-    group: int,
-    eps: float,
-) -> tuple[TensorType, TensorType, TensorType]:
+    input: TFloat,
+    weight: Optional[TFloat],
+    bias: Optional[TFloat],
+    N: INT64 = None,  # pylint: disable=unused-argument
+    C: INT64 = None,  # pylint: disable=unused-argument
+    HxW: INT64 = None,  # pylint: disable=unused-argument
+    group: int = None,
+    eps: float = None,
+) -> TFloat:
+    # FIXME: we can only return one TReal instead of (x, y, z), because we don't
+    # know how to compute running_var and running_mean. There is no
+    # native_group_norm test case, and torch's group_norm returns only one output.
     """native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)"""

-    raise NotImplementedError()
+    # Create weight_instance_norm and bias_instance_norm
+    weight_inst = op.Constant(value_floats=[1.0] * group)
+    bias_inst = op.Constant(value_floats=[0.0] * group)
+    # A 0 in the shape list keeps that dimension value unchanged, as InstanceNormalization requires
+    shape = op.Constant(value_ints=[0, group, -1])
+
+    return _aten_native_group_norm_onnx(
+        input, weight, bias, weight_inst, bias_inst, shape, eps
+    )
+
+
+@torch_op("aten::native_group_norm", private=True)
+def _aten_native_group_norm_onnx(
+    input: TFloat,
+    weight: Optional[TFloat],
+    bias: Optional[TFloat],
+    weight_inst: TFloat,
+    bias_inst: TFloat,
+    shape: INT64,
+    eps: float = None,
+) -> TReal:
+    input_reshaped = op.Reshape(input, shape)
+    norm_reshaped = op.InstanceNormalization(input_reshaped, weight_inst, bias_inst, epsilon=eps)
+    norm = op.Reshape(norm_reshaped, op.Shape(input))
+    input_rank = op.Size(op.Shape(input))
+    axes = op.Range(1, input_rank - 1, 1)
+    # Apply the real weight and bias after normalization
+    return op.Add(op.Mul(norm, op.Unsqueeze(weight, axes)), op.Unsqueeze(bias, axes))


 def aten_native_group_norm_backward(
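The decomposition above (Reshape to [0, group, -1], InstanceNormalization with unit weight and zero bias, Reshape back, then Mul/Add with the real weight and bias) can be checked numerically against torch.nn.functional.group_norm. A minimal sketch, not part of the PR; shapes, group count, and tolerance are illustrative:

import torch
import torch.nn.functional as F

N, C, H, W, group, eps = 2, 6, 4, 4, 3, 1e-5
x = torch.randn(N, C, H, W)
weight, bias = torch.randn(C), torch.randn(C)

expected = F.group_norm(x, group, weight, bias, eps)  # reference result

# Mirror the ONNX graph: Reshape -> InstanceNormalization -> Reshape -> Mul/Add.
x_grouped = x.reshape(N, group, -1)  # the [0, group, -1] shape constant
norm = F.instance_norm(
    x_grouped,
    weight=torch.ones(group),  # weight_inst
    bias=torch.zeros(group),   # bias_inst
    eps=eps,
).reshape_as(x)
# Unsqueezing weight/bias over axes 1..rank-2 makes them broadcast over H and W.
actual = norm * weight[:, None, None] + bias[:, None, None]

assert torch.allclose(actual, expected, atol=1e-5)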
19 changes: 19 additions & 0 deletions onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -379,6 +379,17 @@ def _mse_loss_input_wrangler(
     return args, kwargs


+def _native_group_norm_input_wrangler(
+    args: list[Any], kwargs: dict[str, Any]
+) -> tuple[list[Any], dict[str, Any]]:
+    kwargs["group"] = args.pop(1)  # move group (int) to kwargs as an attribute
+    args.append(kwargs["weight"])  # move weight (tensor) to args as an input
+    args.append(kwargs["bias"])  # move bias (tensor) to args as an input
+    del kwargs["weight"]
+    del kwargs["bias"]
+    return args, kwargs
+
+
 def _nll_loss_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
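The wrangler adapts the nn.functional.group_norm OpInfo samples to the ATen signature: num_groups moves from a positional argument to the group attribute, while weight and bias move from kwargs to graph inputs. A hypothetical walk-through, assuming the wrangler is in scope and that the harness passes args as [input, num_groups]:

import torch

x = torch.randn(2, 6, 4, 4)
w, b = torch.randn(6), torch.randn(6)
args = [x, 2]  # [input, num_groups], as produced by the group_norm OpInfo samples
kwargs = {"weight": w, "bias": b, "eps": 1e-5}
args, kwargs = _native_group_norm_input_wrangler(args, kwargs)
# Now: args == [x, w, b]                    (weight/bias became graph inputs)
#      kwargs == {"group": 2, "eps": 1e-5}  (group became an attribute)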
@@ -701,6 +712,7 @@ def _where_input_wrangler(
     "index_select": core_ops.aten_index_select,
     "layer_norm": core_ops.aten_layer_norm,
     "max": core_ops.aten_max,
+    "native_group_norm": (core_ops.aten_native_group_norm, _native_group_norm_input_wrangler),
     "native_layer_norm": core_ops.aten_native_layer_norm,
     "new_empty": core_ops.aten_new_empty,
     "new_empty_strided": core_ops.aten_new_empty_strided,
@@ -997,6 +1009,11 @@ def _where_input_wrangler(
         or (len(sample.args) > 0 and not isinstance(sample.args[0], int)),
         reason="this ATen overload only support one tensor as input and another int as args",
     ),
+    skip(
+        "native_group_norm",
+        matcher=lambda sample: len(sample.input.shape) == 2,
+        reason="ONNX only supports input rank >= 3",
+    ),
     skip(
         "new_ones",
         matcher=lambda sample: sample.kwargs.get("dtype") is not None,
@@ -1213,6 +1230,8 @@ def _where_input_wrangler(

 duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))

+duplicate_opinfo(OPS_DB, "nn.functional.group_norm", ("native_group_norm",))
+
 duplicate_opinfo(OPS_DB, "new_ones", ("new_ones_dtype",))

 duplicate_opinfo(OPS_DB, "new_zeros", ("new_zeros_dtype",))