add op(native group norm) | feat(atenlib) #644

Merged 30 commits from xiaowu/addOp(native_group_norm) into main on Apr 21, 2023
Commits (30)
47447c1
update
xiaowuhu Mar 30, 2023
081c55c
Update core.py
xiaowuhu Mar 30, 2023
f056fd7
Update ops_correctness_test.py
xiaowuhu Apr 3, 2023
5802188
update
xiaowuhu Apr 17, 2023
4fb749d
Update core.py
xiaowuhu Apr 17, 2023
4f49db6
Update core.py
xiaowuhu Apr 17, 2023
5500bd1
Update core.py
xiaowuhu Apr 17, 2023
9e85674
Merge branch 'main' into xiaowu/addOp(native_group_norm)
xiaowuhu Apr 17, 2023
77808ef
Merge branch 'main' into xiaowu/addOp(native_group_norm)
xiaowuhu Apr 17, 2023
35885a9
Merge branch 'xiaowu/addOp(native_group_norm)' of https://github.com/…
xiaowuhu Apr 17, 2023
264de39
update
xiaowuhu Apr 17, 2023
4761f0d
Update core.py
xiaowuhu Apr 17, 2023
27dce7e
Update core.py
xiaowuhu Apr 17, 2023
f40f0d3
update
xiaowuhu Apr 18, 2023
6411efe
Merge branch 'main' into xiaowu/addOp(native_group_norm)
xiaowuhu Apr 18, 2023
f90e095
update
xiaowuhu Apr 18, 2023
e624822
update
xiaowuhu Apr 18, 2023
497659a
Update extra_opinfo.py
xiaowuhu Apr 18, 2023
c83905c
update
xiaowuhu Apr 19, 2023
fac4c13
Merge branch 'main' into xiaowu/addOp(native_group_norm)
xiaowuhu Apr 19, 2023
8f68565
Update core.py
xiaowuhu Apr 19, 2023
6b89bb1
Merge branch 'xiaowu/addOp(native_group_norm)' of https://github.com/…
xiaowuhu Apr 19, 2023
a8fc8f7
Update core.py
xiaowuhu Apr 19, 2023
e50c817
Merge branch 'main' into xiaowu/addOp(native_group_norm)
xiaowuhu Apr 19, 2023
de9359b
Merge branch 'main' into xiaowu/addOp(native_group_norm)
xiaowuhu Apr 19, 2023
e2924f4
Merge branch 'main' into xiaowu/addOp(native_group_norm)
xiaowuhu Apr 20, 2023
958820b
update
xiaowuhu Apr 20, 2023
b38af29
Update core.py
xiaowuhu Apr 20, 2023
c3e3b67
Update core.py
xiaowuhu Apr 20, 2023
7005a61
Merge branch 'main' into xiaowu/addOp(native_group_norm)
xiaowuhu Apr 20, 2023
81 changes: 71 additions & 10 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -4288,19 +4288,80 @@ def aten_native_dropout_backward(
raise NotImplementedError()


@torch_op("aten::native_group_norm", trace_only=True)
def aten_native_group_norm(
input: TensorType,
weight: Optional[TensorType],
bias: Optional[TensorType],
N: INT64,
C: INT64,
HxW: INT64,
group: int,
eps: float,
) -> tuple[TensorType, TensorType, TensorType]:
input: TFloat,
weight: Optional[TFloat] = None,
bias: Optional[TFloat] = None,
N: Optional[INT64] = None, # pylint: disable=unused-argument
C: Optional[INT64] = None, # pylint: disable=unused-argument
HxW: Optional[INT64] = None, # pylint: disable=unused-argument
group: int = 1,
eps: float = 1e-05,
) -> Tuple[TFloat, TFloat, TFloat]:
"""native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)"""

raise NotImplementedError()
# We don't actually need the N, C, HxW values because the input tensor already carries that information
if weight is None:  # Default weight to 1.0; its shape is the channel size
weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))

if bias is None:  # Default bias to 0.0; its shape is the channel size
bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))

# Following Torch, return rstd instead of var
norm, mean, rstd = _aten_native_group_norm_onnx(input, weight, bias, group, eps)
return norm, mean, rstd


@torch_op("aten::native_group_norm", private=True)
def _aten_native_group_norm_onnx(
input: TFloat,
weight: TFloat,
bias: TFloat,
group: INT64,
eps: float,
) -> Tuple[TFloat, TFloat, TFloat]:
# onnx.GroupNorm() needs weight and bias of size=group,
# but the Torch aten op takes them with size=channel, so the sizes mismatch.
# We therefore use onnx.InstanceNorm() to simulate group normalization.
neg_1 = op.Constant(value_ints=[-1])
# Create weight_instance_norm and bias_instance_norm, copied from Torch ONNX converter
group_tensor = op.Reshape(group, neg_1)
# A 0 in the shape list keeps that dimension unchanged; InstanceNorm needs shape [0, group, -1]
shape_input = op.Concat(op.Constant(value_ints=[0]), group_tensor, neg_1, axis=0)
input_reshaped = op.Reshape(input, shape_input)
weight_inst_norm = op.Expand(op.Constant(value_floats=[1.0]), group_tensor)
bias_inst_norm = op.Expand(op.Constant(value_floats=[0.0]), group_tensor)
norm = op.InstanceNormalization(
input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
)
# Reshape back to input's shape
norm = op.Reshape(norm, op.Shape(input))
# Apply the affine transform using the input weight and bias,
# unsqueezing them to the target shape so they broadcast easily
input_rank = op.Size(op.Shape(input))
axes_unsqueeze = op.Range(1, input_rank - 1, 1)
weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
norm_mul_weight = op.Mul(norm, weight_full_shape)
norm_result = op.Add(norm_mul_weight, bias_full_shape)
# Compute mean and rstd following the Torch algorithm
# Reshape the input to [N, group, -1] so the reductions below produce [N, group]
N = op.Shape(input, start=0, end=1)
shape_N_group_neg1 = op.Concat(N, group_tensor, neg_1, axis=0)
input_N_group_neg1 = op.Reshape(input, shape_N_group_neg1)
# The output size is [N, group], so dims = [2]
axes = op.Constant(value_ints=[2])
# Get the mean with shape [N, group, 1] so it broadcasts in the subtraction below
mean = op.ReduceMean(input_N_group_neg1, axes)
input_sub_mean = op.Sub(input_N_group_neg1, mean)
sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
# In PyTorch, rstd = 1 / sqrt(var + eps)
var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=0)
rstd = op.Div(1.0, op.Sqrt(var + eps))
# Get the correct shape [N, group] for mean again
mean = op.ReduceMean(input_N_group_neg1, axes, keepdims=0)
return norm_result, mean, rstd


def aten_native_group_norm_backward(
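For intuition (not part of the PR), the reshape-to-[N, group, -1] trick used above can be reproduced in plain PyTorch and checked against the eager aten op; the shapes, tolerances, and variable names below are illustrative assumptions, not code from this change.

import torch

# Illustrative sanity check, assuming small float32 tensors on CPU.
N, C, HxW, group, eps = 2, 6, 3, 3, 1e-5
x = torch.randn(N, C, HxW)
weight = torch.randn(C)
bias = torch.randn(C)

# Reference results from the aten op this PR implements for the exporter.
out_ref, mean_ref, rstd_ref = torch.ops.aten.native_group_norm(
    x, weight, bias, N, C, HxW, group, eps
)

# Same math as the ONNX function: normalize each (N, group) slice,
# then apply the channel-wise affine transform with weight and bias.
xg = x.reshape(N, group, -1)
mean = xg.mean(dim=2, keepdim=True)
var = ((xg - mean) ** 2).mean(dim=2, keepdim=True)
out = ((xg - mean) / torch.sqrt(var + eps)).reshape(x.shape)
out = out * weight.view(1, C, 1) + bias.view(1, C, 1)

torch.testing.assert_close(out, out_ref, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(mean.squeeze(-1), mean_ref, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(1 / torch.sqrt(var.squeeze(-1) + eps), rstd_ref, rtol=1e-4, atol=1e-4)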
41 changes: 41 additions & 0 deletions onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
@@ -227,6 +227,39 @@ def sample_inputs_max_pool3d_with_indices(op_info, device, dtype, requires_grad,
yield opinfo_core.SampleInput(arg, kwargs=kwargs)


def sample_inputs_native_group_norm(op_info, device, dtype, requires_grad, **kwargs):
del op_info
make_arg = functools.partial(
torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)

# Ordered as: input shape, weight shape, bias shape, N, C, HxW, and kwargs for group and eps
cases = (
((1, 6, 3), (6,), (6,), 1, 6, 3, {"group": 2, "eps": 0.5}),
((2, 6, 3), (6,), (6,), 2, 6, 3, {"group": 3, "eps": -0.5}),
((5, 5, 5), (5,), (5,), 5, 5, 5, {"group": 1, "eps": 1e-5}),
((5, 8, 10), (8,), (8,), 5, 8, 10, {"group": 4, "eps": 1e-5}),
)

for input_shape, weight, bias, N, C, HxW, kwargs in cases:
# weight and bias, when present, must have shape (channels,)
channels = input_shape[1] if len(input_shape) > 1 else 0
weight = make_arg(channels) if channels > 0 else None
bias = make_arg(channels) if channels > 0 else None

yield opinfo_core.SampleInput(
make_arg(input_shape),
args=(
weight,
bias,
N,
C,
HxW,
),
kwargs=kwargs,
)


def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
del op_info
# input_shape, output_size, kernal, dilation, padding, stride
@@ -316,6 +349,14 @@ def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
skips=(),
supports_out=False,
),
opinfo_core.OpInfo(
"native_group_norm",
op=torch.ops.aten.native_group_norm,
aten_name="native_group_norm",
dtypes=common_dtype.floating_and_complex_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_native_group_norm,
supports_out=False,
),
opinfo_core.OpInfo(
"nn.functional.conv3d",
aliases=("conv3d",),
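For reference, the first sample case in sample_inputs_native_group_norm above corresponds roughly to the following eager call; the random values and the printed shapes are illustrative only.

import torch

# Hypothetical expansion of the first sample: (1, 6, 3) input, group=2, eps=0.5.
x = torch.randn(1, 6, 3)
weight = torch.randn(6)
bias = torch.randn(6)
out, mean, rstd = torch.ops.aten.native_group_norm(x, weight, bias, 1, 6, 3, 2, 0.5)
print(out.shape, mean.shape, rstd.shape)  # (1, 6, 3), (1, 2), (1, 2)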
@@ -701,6 +701,7 @@ def _where_input_wrangler(
"index_select": core_ops.aten_index_select,
"layer_norm": core_ops.aten_layer_norm,
"max": core_ops.aten_max,
"native_group_norm": core_ops.aten_native_group_norm,
"native_batch_norm": core_ops.aten_native_batch_norm,
"native_layer_norm": core_ops.aten_native_layer_norm,
"new_empty": core_ops.aten_new_empty,
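The hunk above registers the new function in the op-to-implementation mapping used by the correctness tests (ops_correctness_test.py, per the commit history). Assuming the repository's usual pytest setup, the op can be exercised in isolation with something like:

python -m pytest onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py -k native_group_norm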