diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index d0ecd01eb..7fe76a6de 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -6,7 +6,7 @@ on: - main - 'gh/**/base' # ghstack base branches pull_request: - types: [opened, synchronize, reopened, ready_for_review] + merge_group: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 64609c070..417fd908d 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -13,6 +13,7 @@ on: # Allows you to run this workflow manually from the Actions tab workflow_dispatch: + merge_group: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} diff --git a/docs/tutorial/rewriter/examples/erfgelu.py b/docs/tutorial/rewriter/examples/erfgelu.py index a7f16cea0..f32ade37c 100644 --- a/docs/tutorial/rewriter/examples/erfgelu.py +++ b/docs/tutorial/rewriter/examples/erfgelu.py @@ -87,7 +87,7 @@ def erf_gelu_pattern_2(op, x): def gelu(op, x: ir.Value): - return op.Gelu(x, domain="com.microsoft") + return op.Gelu(x, _domain="com.microsoft") #################################### diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index ab97c5f98..c1a2afbfb 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -100,6 +100,10 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True): "cannot import module, import_module does not work", ), skip("^test_bitwise_not_3d", "cannot import module, import_module does not work"), + skip( + "^test_resize_upsample_scales_linear_half_pixel_symmetric", + "cannot import module, import_module does not work", + ), ) diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index 4fac129ef..bef78a799 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -459,8 +459,18 @@ def _add_attribute_to_torchscript_node( return node.fs_(key, list(value)) # type: ignore[arg-type] if isinstance(value[0], int): return node.is_(key, list(value)) # type: ignore[attr-defined] - raise TypeError(f"Unsupported sequence type '{type(value)}' for attribute '{key}'") - raise TypeError(f"Unsupported attribute type '{type(value)}' for attribute '{key}'") + raise TypeError( + f"Unsupported sequence type '{type(value)}' for attribute '{key}' in " + f"node={node!r}, value is {value!r}" + ) + if "TensorProtoDataType" in str(type(value)): + # torch._C._onnx.TensorProtoDataType + return node.i_(key, int(value)) + + raise TypeError( + f"Unsupported attribute type '{type(value)}' for attribute '{key}' " + f"in node={node!r}, value is {value!r}" + ) @runtime_typing.checked diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index f984ed6b9..d7e97e98d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -45,7 +45,6 @@ TInt, TReal, TRealOrUInt8, - TRealUnlessFloat16OrInt8, TRealUnlessInt16OrInt8, TTensor, TTensor2, @@ -83,11 +82,15 @@ def aten__log_softmax_half( ) -> FLOAT: """_log_softmax(Tensor self, int 
dim, bool half_to_float) -> Tensor""" - # trace_only because we need to cast conditionally based on half_to_float + self_is_scalar = IsScalar(self) if half_to_float: self = op.Cast(self, to=FLOAT.dtype) - - return aten__log_softmax(self, dim, half_to_float) + if self_is_scalar: + self = op.Unsqueeze(self, op.Constant(value_ints=[0])) + result = op.LogSoftmax(self, axis=dim) + if self_is_scalar: + result = op.Squeeze(result, op.Constant(value_ints=[0])) + return result @torch_op("aten::_log_softmax", traceable=True) @@ -102,7 +105,7 @@ def aten__log_softmax( if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) result = op.LogSoftmax(self, axis=dim) - if self_is_scalar: # squeeze to scalar due to input is scalar + if self_is_scalar: result = op.Squeeze(result) return result @@ -150,14 +153,14 @@ def aten_abs_complex(self: TRealOrUInt8) -> TRealOrUInt8: return op.Squeeze(op.Sqrt(real_plus_imag), axes=[-1]) -@torch_op("aten::acos") +@torch_op("aten::acos", traceable=True) def aten_acos(self: TFloat) -> TFloat: """acos(Tensor self) -> Tensor""" return op.Acos(self) -@torch_op("aten::acosh") +@torch_op("aten::acosh", traceable=True) def aten_acosh(self: TFloat) -> TFloat: """acosh(Tensor self) -> Tensor""" @@ -542,7 +545,7 @@ def _integral_to_be_adjusted(dtype: int) -> bool: @torch_op("aten::arange", trace_only=True) def aten_arange( - end: Union[DOUBLE, FLOAT, INT16, INT32, INT64], + end: float, dtype: int = -1, layout: str = "", device: str = "", @@ -550,13 +553,11 @@ def aten_arange( ) -> TensorType: """arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - # NOTE: trace_only because both if branches need to be the same type, but we have - # a cast in the if branch. - - if dtype == -1: - zero = op.CastLike(0.0, end) - one = op.CastLike(1.0, end) - result = op.Range(zero, end, one) + if dtype == -1 or dtype is None: + if isinstance(end, int): + result = op.Range(0, end, 1) + else: + result = op.Range(0.0, end, 1.0) elif _range_supported(dtype): end = op.Cast(end, to=dtype) zero = op.Cast(0, to=dtype) @@ -567,7 +568,7 @@ def aten_arange( # because the input dtype may be e.g. bfloat16 / int8 etc. # which Range does not support. The output type is ensured because the output # is casted to the specified dtype. - end = op.Cast(end, to=FLOAT.dtype) + end = op.Constant(value_float=float(end)) zero = op.Constant(value_float=0.0) one = op.Constant(value_float=1.0) result = op.Cast(op.Range(zero, end, one), to=dtype) @@ -577,8 +578,8 @@ def aten_arange( @torch_op("aten::arange.start", trace_only=True) def aten_arange_start( - start: TRealUnlessFloat16OrInt8, - end: TRealUnlessFloat16OrInt8, + start: float, + end: float, dtype: int = -1, layout: str = "", device: str = "", @@ -586,12 +587,13 @@ def aten_arange_start( ) -> TensorType: """arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - # NOTE: trace_only because both if branches need to be the same type, but we have - # a cast in the if branch. 
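
Note on the `arange` rewrite above: with the trace-only signature, `end` is a Python scalar, so the int-vs-float choice is made at trace time instead of with `CastLike`. A minimal NumPy sketch of the intended dtype behavior (NumPy is used purely for illustration; the actual implementation emits ONNX `Range`/`Cast` nodes):

```python
import numpy as np

def arange_dtype_sketch(end, dtype=None) -> np.ndarray:
    """Illustrative only: mirrors the branching in the new aten_arange."""
    if dtype is None:
        if isinstance(end, int):
            # Integer start/limit/delta -> integer result, like Range(0, end, 1)
            return np.arange(0, end, 1)
        # Any float argument -> float result, like Range(0.0, end, 1.0)
        return np.arange(0.0, float(end), 1.0)
    # Explicit dtype: compute in a Range-friendly float type, then cast the output
    return np.arange(0.0, float(end), 1.0).astype(dtype)

print(arange_dtype_sketch(5).dtype)            # platform default int
print(arange_dtype_sketch(5.0).dtype)          # float64
print(arange_dtype_sketch(5, np.int8).dtype)   # int8, via the cast-at-the-end branch
```
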
- - if dtype == -1: - one = op.CastLike(1.0, end) - result = op.Range(start, end, one) + if dtype == -1 or dtype is None: + if isinstance(start, int) and isinstance(end, int): + result = op.Range(start, end, 1) + else: + start = float(start) + end = float(end) + result = op.Range(start, end, 1.0) elif _range_supported(dtype): end = op.Cast(end, to=dtype) start = op.Cast(start, to=dtype) @@ -602,36 +604,32 @@ def aten_arange_start( # because the input dtype may be e.g. bfloat16 / int8 etc. # which Range does not support. The output type is ensured because the output # is casted to the specified dtype. - end = op.Cast(end, to=FLOAT.dtype) - start = op.Cast(start, to=FLOAT.dtype) + end = op.Constant(value_float=float(end)) + start = op.Constant(value_float=float(start)) one = op.Constant(value_float=1.0) result = op.Cast(op.Range(start, end, one), to=dtype) return result -@torch_op("aten::arange.start_step", private=True) def _adjust_args_for_arange_int_dtype( - start: TRealUnlessFloat16OrInt8, - end: TRealUnlessFloat16OrInt8, - step: TRealUnlessFloat16OrInt8, -) -> Tuple[FLOAT, FLOAT, FLOAT]: - zero = op.Cast(0.0, to=FLOAT.dtype) - start = op.Cast(start, to=FLOAT.dtype) - end = op.Cast(end, to=FLOAT.dtype) - step = op.Cast(step, to=FLOAT.dtype) - - start = op.Where(op.Less(start, zero), op.Ceil(start), start) - start = op.Where(op.Less(step, zero), op.Floor(start), start) + start: float, + end: float, + step: float, +) -> Tuple[float, float, float]: + if start < 0: + start = math.ceil(start) + if step < 0: + start = math.floor(start) - return (start, end, step) + return float(start), float(end), float(step) @torch_op("aten::arange.start_step", trace_only=True) def aten_arange_start_step( - start: TRealUnlessFloat16OrInt8, - end: TRealUnlessFloat16OrInt8, - step: TRealUnlessFloat16OrInt8, + start: float, + end: float, + step: float = 1.0, dtype: int = -1, layout: str = "", device: str = "", @@ -639,11 +637,14 @@ def aten_arange_start_step( ) -> TensorType: """arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - # NOTE: trace_only because both if branches need to be the same type, but we have - # a cast in the if branch. - - if dtype == -1: - result = op.Range(start, end, step) + if dtype == -1 or dtype is None: + if isinstance(start, int) and isinstance(end, int): + result = op.Range(start, end, int(step)) + else: + start = float(start) + end = float(end) + step = float(step) + result = op.Range(start, end, step) elif _integral_to_be_adjusted(dtype): # PyTorch arange op handles these integral types differently from INT64, # so we have to adjust these arguments accordingly. @@ -651,18 +652,18 @@ def aten_arange_start_step( start, end, step = _adjust_args_for_arange_int_dtype(start, end, step) result = op.Cast(op.Range(start, end, step), to=dtype) elif dtype == INT64.dtype: - end = op.Cast(end, to=dtype) - start = op.Cast(start, to=dtype) - step = op.Cast(step, to=dtype) + end = int(end) + start = int(start) + step = int(step) result = op.Range(start, end, step) else: # Cast input to float if dtype is not supported by Range, # because the input dtype may be e.g. bfloat16, # which Range does not support. The output type is ensured because the output # is casted to the specified dtype. 
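
The integral-dtype adjustment is now plain Python executed at trace time. A standalone sketch of that logic (the helper name `_adjust_args` below is illustrative, mirroring `_adjust_args_for_arange_int_dtype`) with a few spot checks:

```python
import math
from typing import Tuple

def _adjust_args(start: float, end: float, step: float) -> Tuple[float, float, float]:
    # Mirrors the trace-time adjustment: a negative fractional start is ceiled,
    # and a negative step floors it again, matching what the previous ONNX
    # version expressed with Where/Ceil/Floor nodes.
    if start < 0:
        start = math.ceil(start)
    if step < 0:
        start = math.floor(start)
    return float(start), float(end), float(step)

assert _adjust_args(-2.5, 5, 1) == (-2.0, 5.0, 1.0)    # negative start -> ceil
assert _adjust_args(2.5, -5, -1) == (2.0, -5.0, -1.0)  # negative step -> floor
assert _adjust_args(3.0, 7.0, 2.0) == (3.0, 7.0, 2.0)  # otherwise unchanged
```
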
- end = op.Cast(end, to=FLOAT.dtype) - start = op.Cast(start, to=FLOAT.dtype) - step = op.Cast(step, to=FLOAT.dtype) + end = float(end) + start = float(start) + step = float(step) result = op.Cast(op.Range(start, end, step), to=dtype) return result @@ -894,21 +895,21 @@ def aten_as_strided_scatter( raise NotImplementedError() -@torch_op("aten::asin") +@torch_op("aten::asin", traceable=True) def aten_asin(self: TFloat) -> TFloat: """asin(Tensor self) -> Tensor""" return op.Asin(self) -@torch_op("aten::asinh") +@torch_op("aten::asinh", traceable=True) def aten_asinh(self: TFloat) -> TFloat: """asinh(Tensor self) -> Tensor""" return op.Asinh(self) -@torch_op("aten::atan") +@torch_op("aten::atan", traceable=True) def aten_atan(self: TFloat) -> TFloat: """atan(Tensor self) -> Tensor""" @@ -929,7 +930,7 @@ def aten_atan2(self: TFloat, other: TFloat) -> TFloat: return result -@torch_op("aten::atanh") +@torch_op("aten::atanh", traceable=True) def aten_atanh(self: TFloat) -> TFloat: """atanh(Tensor self) -> Tensor""" @@ -1229,6 +1230,7 @@ def aten_bitwise_and(self: TInt, other: TInt) -> TInt: "aten::bitwise_left_shift.Tensor", "aten::bitwise_left_shift.Tensor_Scalar", "aten::bitwise_left_shift.Scalar_Tensor", + "_operator::__lshift__", ), traceable=True, ) @@ -1248,6 +1250,7 @@ def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16: "aten::bitwise_left_shift.Tensor", "aten::bitwise_left_shift.Tensor_Scalar", "aten::bitwise_left_shift.Scalar_Tensor", + "_operator::__lshift__", ), traceable=True, ) @@ -1267,6 +1270,7 @@ def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32: "aten::bitwise_left_shift.Tensor", "aten::bitwise_left_shift.Tensor_Scalar", "aten::bitwise_left_shift.Scalar_Tensor", + "_operator::__lshift__", ), traceable=True, ) @@ -1286,6 +1290,7 @@ def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64: "aten::bitwise_left_shift.Tensor", "aten::bitwise_left_shift.Tensor_Scalar", "aten::bitwise_left_shift.Scalar_Tensor", + "_operator::__lshift__", ), traceable=True, ) @@ -1329,6 +1334,7 @@ def aten_bitwise_or(self: TInt, other: TInt) -> TInt: "aten::bitwise_right_shift.Tensor", "aten::bitwise_right_shift.Tensor_Scalar", "aten::bitwise_right_shift.Scalar_Tensor", + "_operator::__rshift__", ) ) def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16: @@ -1358,6 +1364,7 @@ def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16: "aten::bitwise_right_shift.Tensor", "aten::bitwise_right_shift.Tensor_Scalar", "aten::bitwise_right_shift.Scalar_Tensor", + "_operator::__rshift__", ) ) def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32: @@ -1387,6 +1394,7 @@ def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32: "aten::bitwise_right_shift.Tensor", "aten::bitwise_right_shift.Tensor_Scalar", "aten::bitwise_right_shift.Scalar_Tensor", + "_operator::__rshift__", ) ) def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64: @@ -1419,6 +1427,7 @@ def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64: "aten::bitwise_right_shift.Tensor", "aten::bitwise_right_shift.Tensor_Scalar", "aten::bitwise_right_shift.Scalar_Tensor", + "_operator::__rshift__", ) ) def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: @@ -1655,40 +1664,32 @@ def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = return clamped -@torch_op("aten::clamp_max", traceable=True) +@torch_op(("aten::clamp_max", "aten::clamp_max.Tensor"), traceable=True) def 
aten_clamp_max(self: TReal, max_: TReal) -> TReal: """clamp_max(Tensor self, Tensor max) -> Tensor""" - self_size = op.Size(self) - max_shape = op.Shape(max_) - max_rank = op.Size(max_shape) - if self_size == 0: - result = op.Expand(self, max_shape) + # This implementation does not intent to handle when self is an empty tensor + max_rank = Rank(max_) + if max_rank == 0: + max_ = op.CastLike(max_, self) + result = op.Clip(self, None, max_) else: - if max_rank == 0: - max_ = op.CastLike(max_, self) - result = op.Clip(self, None, max_) - else: - result = op.Min(self, max_) + result = op.Min(self, max_) return result -@torch_op("aten::clamp_min", traceable=True) +@torch_op(("aten::clamp_min", "aten::clamp_min.Tensor"), traceable=True) def aten_clamp_min(self: TReal, min_: TReal) -> TReal: """clamp_min(Tensor self, Tensor min) -> Tensor""" - self_size = op.Size(self) - min_shape = op.Shape(min_) - min_rank = op.Size(min_shape) - if self_size == 0: - result = op.Expand(self, min_shape) + # This implementation does not intent to handle when self is an empty tensor + min_rank = Rank(min_) + if min_rank == 0: + min_ = op.CastLike(min_, self) + result = op.Clip(self, min_, None) else: - if min_rank == 0: - min_ = op.CastLike(min_, self) - result = op.Clip(self, min_, None) - else: - result = op.Max(self, min_) + result = op.Max(self, min_) return result @@ -2194,14 +2195,14 @@ def aten_corrcoef(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::cos") +@torch_op("aten::cos", traceable=True) def aten_cos(self: TFloat) -> TFloat: """cos(Tensor self) -> Tensor""" return op.Cos(self) -@torch_op("aten::cosh") +@torch_op("aten::cosh", traceable=True) def aten_cosh(self: TFloat) -> TFloat: """cosh(Tensor self) -> Tensor""" @@ -2545,19 +2546,11 @@ def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> # This is because computing diagonal sum is on dim2 after transpose by perm axes = [self_rank - 2] - return _aten_diagonal_onnx(self, offset, dim1, dim2, perm, axes) - - -@torch_op("aten::diagonal", private=True, traceable=True) -def _aten_diagonal_onnx( - self: TTensor, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int] -) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) - tmp_tensor = op.ConstantOfShape(mask_shape) - mask = op.EyeLike(tmp_tensor, k=offset) + mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) mask = op.CastLike(mask, self) self_t = op.Transpose(self, perm=perm) result = op.Mul(self_t, mask) @@ -2583,18 +2576,19 @@ def _aten_diagonal_onnx( # 6 0 4 0 # From above table, we can get the logic below + offset_val = op.Constant(value_ints=[offset]) if offset < 0: # row + offset - length = dim1_size + offset + length = op.Add(dim1_size, offset_val) start = op.Constant(value_ints=[0]) else: # offset >= 0 # col - offset - length = dim2_size - offset - start = op.Reshape(op.Constant(value_int=offset), neg_1) + length = op.Sub(dim2_size, offset_val) + start = offset_val # max(min(length, min(row, col)), 0) - length = op.Max(op.Min(length, min_dim_size), 0) - end = start + length + length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) + end = op.Add(start, length) result = op.Slice(result, start, end, axes=axes) return result @@ -2624,19 +2618,11 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 
0, dim2: int = 1 # This is because computing diagonal sum is on dim2 after transpose by perm axes = [self_rank - 2] - return _aten_diagonal_bool_onnx(self, offset, dim1, dim2, perm, axes) - - -@torch_op("aten::diagonal", private=True) -def _aten_diagonal_bool_onnx( - self: BOOL, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int] -) -> BOOL: neg_1 = op.Constant(value_ints=[-1]) dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) - tmp_tensor = op.ConstantOfShape(mask_shape) - mask = op.EyeLike(tmp_tensor, k=offset) + mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) self_int = op.Cast(self, to=INT64.dtype) mask_int = op.Cast(mask, to=INT64.dtype) self_int_t = op.Transpose(self_int, perm=perm) @@ -2663,18 +2649,19 @@ def _aten_diagonal_bool_onnx( # 6 0 4 0 # From above table, we can get the logic below + offset_val = op.Constant(value_ints=[offset]) if offset < 0: # row + offset - length = dim1_size + offset + length = op.Add(dim1_size, offset_val) start = op.Constant(value_ints=[0]) else: # offset >= 0 # col - offset - length = dim2_size - offset - start = op.Reshape(op.Constant(value_int=offset), neg_1) + length = op.Sub(dim2_size, offset_val) + start = offset_val # max(min(length, min(row, col)), 0) - length = op.Max(op.Min(length, min_dim_size), 0) - end = start + length + length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) + end = op.Add(start, length) result = op.Slice(result, start, end, axes=axes) result = op.Cast(result, to=BOOL.dtype) @@ -3489,7 +3476,7 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType raise NotImplementedError() -@torch_op(("aten::fill.Tensor", "aten::fill.Sclaar")) +@torch_op(("aten::fill.Tensor", "aten::fill.Scalar")) def aten_fill(self: TTensor, value: TTensor2) -> TTensor: """fill.Tensor(Tensor self, Tensor value) -> Tensor""" @@ -3606,30 +3593,31 @@ def aten_from_file( @torch_op("aten::full", trace_only=True) def aten_full( - size: INT64, - fill_value: FLOAT, + size: Union[INT64, INT32], + fill_value: TensorType, dtype: int = FLOAT.dtype, layout: str = "", device: str = "", pin_memory: bool = False, -): +) -> TensorType: """full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - size = op.Cast(size, to=INT64.dtype) if dtype != -1: fill_value = op.Cast(fill_value, to=dtype) + + size = op.Cast(size, to=INT64.dtype) return op.Expand(fill_value, size) @torch_op("aten::full_like", trace_only=True) def aten_full_like( - self: TTensor, - fill_value: TTensor, + self: TensorType, + fill_value: TensorType, dtype: int = -1, layout: str = "", device: str = "", pin_memory: bool = False, -) -> TTensor: +) -> TensorType: """full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" if dtype == -1: @@ -3830,19 +3818,6 @@ def aten_grid_sampler_3d_backward( raise NotImplementedError() -def aten_group_norm( - input: TensorType, - num_groups: int, - weight: Optional[TensorType] = None, - bias: Optional[TensorType] = None, - eps: float = 1e-05, - cudnn_enabled: bool = True, -) -> TensorType: - """group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? 
bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor""" - - raise NotImplementedError() - - def aten_gru_cell( input: TensorType, hx: TensorType, @@ -4715,11 +4690,17 @@ def aten_linear_backward( @torch_op("aten::linspace", trace_only=True) def aten_linspace( - start: TFloat, end: TFloat, steps: int, dtype: int = FLOAT.dtype + start: float, + end: float, + steps: int, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TensorType: """linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - if dtype == -1: + if dtype == -1 or dtype is None: dtype = FLOAT.dtype # Reference: https://github.com/pytorch/pytorch/blob/b35ca2cb941b5ba90858322810ca85c31e4541fd/torch/_refs/__init__.py#L4896 @@ -4728,6 +4709,7 @@ def aten_linspace( if steps == 1: return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype) + # TODO(justinchuby): Simplify the logic knowing start and end are floats rg = aten_arange_start(0, steps, dtype=dtype) start = op.Cast(start, to=dtype) end = op.Cast(end, to=dtype) @@ -4743,14 +4725,14 @@ def aten_linspace( ) -@torch_op("aten::log") +@torch_op("aten::log", traceable=True) def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16: """log(Tensor self) -> Tensor""" return op.Log(self) -@torch_op("aten::log10") +@torch_op("aten::log10", traceable=True) def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16: """log10(Tensor self) -> Tensor""" @@ -4764,21 +4746,21 @@ def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16: return op.Log(op.Add(self, 1.0)) -@torch_op("aten::log2") +@torch_op("aten::log2", traceable=True) def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16: """log2(Tensor self) -> Tensor""" return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self)) -@torch_op("aten::logaddexp") +@torch_op("aten::logaddexp", traceable=True) def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: """logaddexp(Tensor self, Tensor other) -> Tensor""" return op.Log(op.Add(op.Exp(self), op.Exp(other))) -@torch_op("aten::logaddexp2") +@torch_op("aten::logaddexp2", traceable=True) def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: """logaddexp2(Tensor self, Tensor other) -> Tensor""" two = op.CastLike(2.0, self) @@ -4811,7 +4793,7 @@ def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: return result -@torch_op("aten::logdet") +@torch_op("aten::logdet", traceable=True) def aten_logdet(self: TFloat) -> TFloat: """logdet(Tensor self) -> Tensor""" @@ -4824,7 +4806,8 @@ def aten_logdet(self: TFloat) -> TFloat: "aten::bitwise_and.Tensor", "aten::bitwise_and.Scalar", "aten::bitwise_and.Scalar_Tensor", - ) + ), + traceable=True, ) def aten_logical_and(self: BOOL, other: BOOL) -> BOOL: """logical_and(Tensor self, Tensor other) -> Tensor""" @@ -4832,7 +4815,7 @@ def aten_logical_and(self: BOOL, other: BOOL) -> BOOL: return op.And(self, other) -@torch_op(("aten::logical_not", "aten::bitwise_not")) +@torch_op(("aten::logical_not", "aten::bitwise_not"), traceable=True) def aten_logical_not(self: BOOL) -> BOOL: """logical_not(Tensor self) -> Tensor""" @@ -4863,7 +4846,8 @@ def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: "aten::bitwise_xor.Tensor", "aten::bitwise_xor.Scalar", "aten::bitwise_xor.Scalar_Tensor", - ) + ), + traceable=True, ) def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL: """logical_xor(Tensor self, Tensor other) -> 
Tensor""" @@ -4912,12 +4896,6 @@ def aten_logsumexp(self: TFloat, dim: INT64, keepdim: int = False) -> TFloat: return result -def aten_lshift(self: TensorType, other: TensorType) -> TensorType: - """__lshift__.Tensor(Tensor self, Tensor other) -> Tensor""" - - raise NotImplementedError() - - def aten_lstm_cell( input: TensorType, hx: Sequence[TensorType], @@ -6075,7 +6053,9 @@ def _aten_native_group_norm_onnx( axes_unsqueeze = op.Range(1, input_rank - 1, 1) weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze) bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze) + weight_full_shape = op.CastLike(weight_full_shape, norm) norm_mul_weight = op.Mul(norm, weight_full_shape) + bias_full_shape = op.CastLike(bias_full_shape, norm_mul_weight) norm_result = op.Add(norm_mul_weight, bias_full_shape) # Compute mean and rstd, but using Torch algorithm # The returned shape for mean and vstd should be [N, group, -1] @@ -6226,7 +6206,7 @@ def aten_new_empty_strided( def aten_new_full( self: TTensor, size: INT64, - fill_value: TTensor, + fill_value: TensorType, dtype: int = -1, layout: str = "", device: str = "", @@ -7308,12 +7288,6 @@ def aten_rrelu( raise NotImplementedError() -def aten_rshift(self: TensorType, other: TensorType) -> TensorType: - """__rshift__.Tensor(Tensor self, Tensor other) -> Tensor""" - - raise NotImplementedError() - - @torch_op("aten::rsqrt") def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16: """rsqrt(Tensor self) -> Tensor""" @@ -7543,14 +7517,14 @@ def aten_signbit(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::sin") +@torch_op("aten::sin", traceable=True) def aten_sin(self: TFloat) -> TFloat: """sin(Tensor self) -> Tensor""" return op.Sin(self) -@torch_op("aten::sinh") +@torch_op("aten::sinh", traceable=True) def aten_sinh(self: TFloat) -> TFloat: """sinh(Tensor self) -> Tensor""" @@ -7845,16 +7819,90 @@ def aten_stack(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrStr return op.ConcatFromSequence(tensors, axis=dim, new_axis=1) -def aten_std(self: TensorType, unbiased: bool = True) -> TensorType: +@torch_op("aten::std", trace_only=True) +def aten_std(self: TReal, unbiased: bool = True) -> TReal: """std(Tensor self, bool unbiased=True) -> Tensor""" + var = _aten_var_onnx(self, correction=float(unbiased), keepdim=False) + return op.Sqrt(var) - raise NotImplementedError() +@torch_op("aten::std.dim", trace_only=True) +def aten_std_dim( + self: TReal, + dim: Sequence[int], + unbiased: Optional[bool] = True, + keepdim: Optional[bool] = False, +) -> TReal: + """std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor""" + + var = _aten_var_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim) + return op.Sqrt(var) -def aten_std_mean(self: TensorType, unbiased: bool = True) -> tuple[TensorType, TensorType]: + +@torch_op("aten::var.correction", trace_only=True) +def aten_std_correction( + self: TReal, + # FIXME(justinchuby): Make dim Optional[Sequence[int]] + dim: Optional[int] = None, + correction: Optional[float] = None, + keepdim: bool = False, +) -> TReal: + """std.correction(Tensor self, int[1]? dim=None, *, Scalar? 
correction=None, bool keepdim=False) -> Tensor""" + + if correction is None: + correction = 1.0 + + if dim is None: + var = _aten_var_onnx(self, correction=correction, keepdim=keepdim) + else: + var = _aten_var_dim_onnx(self, dims=dim, correction=correction, keepdim=keepdim) + return op.Sqrt(var) + + +@torch_op("aten::std_mean", trace_only=True) +def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: """std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)""" - raise NotImplementedError() + # Assume bool(True) and int(1) are same in ONNX, so pass "unbiased" directly as "correction" + # If not this case, should be explicitly set correction value according to unbiased value + var, mean = _aten_var_mean_onnx(self, correction=float(unbiased), keepdim=False) + return op.Sqrt(var), mean + + +@torch_op("aten::std_mean.dim", trace_only=True) +def aten_std_mean_dim( + self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False +) -> Tuple[TReal, TReal]: + """std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)""" + + # Although dim is Optional in signature, but we assume it must have value for this overload + # Assert(dim is not None) + var, mean = _aten_var_mean_dim_onnx( + self, dims=dim, correction=float(unbiased), keepdim=keepdim + ) + return op.Sqrt(var), mean + + +@torch_op("aten::std_mean.correction", trace_only=True) +def aten_std_mean_correction( + self: TReal, + # FIXME(justinchuby): Make dim Optional[Sequence[int]] + dim: Optional[int] = None, + correction: Optional[float] = None, + keepdim: bool = False, +) -> Tuple[TReal, TReal]: + """std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)""" + + if correction is None: + correction = 1.0 + + if dim is None: + var, mean = _aten_var_mean_onnx(self, correction=correction, keepdim=keepdim) + else: + var, mean = _aten_var_mean_dim_onnx( + self, dims=dim, correction=correction, keepdim=keepdim + ) + return op.Sqrt(var), mean @torch_op("aten::stft", private=True) @@ -8916,7 +8964,14 @@ def aten_zeros( @torch_op("aten::zeros_like", trace_only=True) -def aten_zeros_like(self: TTensor, dtype: int = -1) -> TTensor: +def aten_zeros_like( + self: TTensor, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, + memory_format: str = "", +) -> TTensor: """zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor""" # NOTE: trace_only because both if branches need to be the same type, but we have diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 5e0da20d0..37298f3a9 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -27,6 +27,7 @@ TFloat, TFloatOrBFloat16, TFloatOrUInt8, + TInt, TReal, TTensor, ) @@ -40,58 +41,7 @@ TFloatUnlessFloat32 = TypeVar("TFloatUnlessFloat32", bound=Union[BFLOAT16, FLOAT16, DOUBLE]) -@torch_op("aten::adaptive_avg_pool1d", traceable=True) -def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat: - """adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor""" - - # assert output_size == [1] - # TODO(justinchuby): Specify input constraints - - if Rank(self) == 2: - # Unbatched case - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - pooled = op.GlobalAveragePool(self) - result = op.Squeeze(pooled, op.Constant(value_ints=[0])) - else: - result = op.GlobalAveragePool(self) - - return result - - -@torch_op("aten::adaptive_avg_pool2d", traceable=True) -def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat: - """adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor""" - - # assert output_size == [1, 1] - # TODO(justinchuby): Specify input constraints - - if Rank(self) == 3: - # Unbatched case - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - pooled = op.GlobalAveragePool(self) - result = op.Squeeze(pooled, op.Constant(value_ints=[0])) - else: - result = op.GlobalAveragePool(self) - - return result - - -@torch_op("aten::adaptive_avg_pool3d", traceable=True) -def aten_adaptive_avg_pool3d(self: TFloat, output_size: INT64[3]) -> TFloat: - """adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor""" - - # assert output_size == [1, 1, 1] - # TODO(justinchuby): Specify input constraints - - if Rank(self) == 4: - # Unbatched case - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - pooled = op.GlobalAveragePool(self) - result = op.Squeeze(pooled, op.Constant(value_ints=[0])) - else: - result = op.GlobalAveragePool(self) - - return result +# NOTE: Implementations of adaptive_average_pool are handled by torch decomp def aten_adaptive_max_pool1d( @@ -593,6 +543,56 @@ def aten_glu_backward_jvp( raise NotImplementedError() +@torch_op("aten::group_norm", trace_only=True) +def aten_group_norm( + input: TFloat, + num_groups: int, + weight: Optional[TFloat] = None, + bias: Optional[TFloat] = None, + eps: float = 1e-05, + cudnn_enabled: bool = True, +) -> TensorType: + """group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? 
bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor""" + + # Actually we don't need N,C,HxW value because the input tensor has that information + if weight is None: # Set to 1.0 as default, the shape is Channel size + weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2)) + + if bias is None: # Set to 0.0 as default, the shape is Channel size + bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2)) + + # Because onnx.GroupNorm() need size=group for weight and bias + # But the torch's aten function's input need size=channel, the size mismatched + # So we have to use onnx.InstanceNorm() to simulate + neg_1 = op.Constant(value_ints=[-1]) + # Create weight_instance_norm and bias_instance_norm, copied from Torch ONNX converter + group_tensor = op.Reshape(num_groups, neg_1) + # 0 in the shape list keeps dimension value unchanged, for InstanceNorm need [0,group,-1] + shape_input = op.Concat(op.Constant(value_ints=[0]), group_tensor, neg_1, axis=0) + input_reshaped = op.Reshape(input, shape_input) + weight_inst_norm = op.Expand( + op.CastLike(op.Constant(value_float=1.0), input), group_tensor + ) + bias_inst_norm = op.Expand(op.CastLike(op.Constant(value_float=0.0), input), group_tensor) + norm = op.InstanceNormalization( + input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps + ) + # Reshape back to input's shape + norm = op.Reshape(norm, op.Shape(input)) + # Using the input weight and bias to do affine + # But need to unsqueeze to the target shape for broading cast easy + input_rank = Rank(input) + one = op.Constant(value_int=1) + axes_unsqueeze = op.Range(one, op.Sub(input_rank, one), one) + weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze) + bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze) + weight_full_shape = op.CastLike(weight_full_shape, norm) + norm_mul_weight = op.Mul(norm, weight_full_shape) + bias_full_shape = op.CastLike(bias_full_shape, norm_mul_weight) + norm_result = op.Add(norm_mul_weight, bias_full_shape) + return norm_result + + def aten_glu_jvp(glu: TensorType, x: TensorType, dx: TensorType, dim: int) -> TensorType: """glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor""" @@ -659,16 +659,138 @@ def aten_huber_loss_backward( raise NotImplementedError() +def _get_im2col_indices_along_dim( + input_d: TInt, + kernel_size_d: int, + dilation_d: int, + padding_d: int, + stride_d: int, +): + # Input is always 4-D (N, C, H, W) + # Calculate indices of sliding blocks along spatial dimension + # Slide kernel over input each dim d: + # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1) + # with steps = stride + + blocks_d = input_d + ((padding_d * 2) - (dilation_d * (kernel_size_d - 1))) + + # Stride kernel over input and find starting indices along dim d + blocks_d_indices = op.Range(0, blocks_d, stride_d) + blocks_d_indices = op.Unsqueeze(blocks_d_indices, [0]) + + # Apply dilation on kernel and find its indices along dim d + kernel_grid = op.Range(0, kernel_size_d * dilation_d, dilation_d) + kernel_mask = op.Unsqueeze(kernel_grid, [1]) + + # Broadcast and add kernel staring positions (indices) with + # kernel_grid along dim d, to get block indices along dim d + block_mask = op.Add(blocks_d_indices, kernel_mask) + + return block_mask + + +def _get_im2col_padded_input(input, padding_h, padding_w): + # Input is always 4-D tensor (N, C, H, W) + # Padding tensor has the following format: (padding_h, padding_w) + # Reshape the padding to follow ONNX format: 
(dim1_begin, dim2_begin,...,dim1_end, dim2_end,...) + pad = op.Concat( + op.Constant(value_ints=[0, 0]), + op.Unsqueeze(padding_h, [0]), + op.Unsqueeze(padding_w, [0]), + op.Constant(value_ints=[0, 0]), + op.Unsqueeze(padding_h, [0]), + op.Unsqueeze(padding_w, [0]), + axis=0, + ) + return op.Pad(input, pad) + + +def _get_im2col_output_shape(input, kernel_h, kernel_w): + input_shape = op.Shape(input) + batch_dim = op.Gather(input_shape, 0, axis=0) + channel_dim = op.Gather(input_shape, 1, axis=0) + channel_unfolded = op.Mul(channel_dim, kernel_h * kernel_w) + + return op.Concat( + op.Unsqueeze(batch_dim, [0]), + op.Unsqueeze(channel_unfolded, [0]), + op.Constant(value_ints=[-1]), + axis=0, + ) + + +@torch_op("aten::im2col", trace_only=True) def aten_im2col( - self: TensorType, + self: TReal, kernel_size: Sequence[int], - dilation: Sequence[int], - padding: Sequence[int], - stride: Sequence[int], + dilation: Sequence[int] = (1, 1), + padding: Sequence[int] = (0, 0), + stride: Sequence[int] = (1, 1), ) -> TensorType: - """im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor""" + """im2col(Tensor self, int[2] kernel_size, int[2] dilation=1, int[2] padding=0, int[2] stride=1) -> Tensor""" - raise NotImplementedError() + input_shape = op.Shape(self) + input_h = op.Gather(input_shape, 2, axis=0) + input_w = op.Gather(input_shape, 3, axis=0) + + if not isinstance(kernel_size, Sequence): + kernel_size = (kernel_size, kernel_size) + kernel_sizes = list(kernel_size) + + if not isinstance(dilation, Sequence): + dilation = (dilation, dilation) + dilations = list(dilation) + + if not isinstance(padding, Sequence): + padding = (padding, padding) + pads = list(padding) + + if isinstance(stride, int): + stride = (stride, stride) + strides = list(stride) + + stride_h, stride_w = strides[0], strides[1] + padding_h, padding_w = pads[0], pads[1] + dilation_h, dilation_w = dilations[0], dilations[1] + kernel_h, kernel_w = kernel_sizes[0], kernel_sizes[1] + + blocks_row_indices = _get_im2col_indices_along_dim( + input_h, kernel_h, dilation_h, padding_h, stride_h + ) + blocks_col_indices = _get_im2col_indices_along_dim( + input_w, kernel_w, dilation_w, padding_w, stride_w + ) + + output_shape = _get_im2col_output_shape(self, kernel_h, kernel_w) + padded_input = _get_im2col_padded_input(self, padding_h, padding_w) + + # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1 + # [[[[1., 2., 3.,], + # [4., 5., 6.,], + # [7., 8., 9.,]]]] + # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get: + # [[[[[1., 2., 3.], + # [4., 5., 6.]], + # [[4., 5., 6.], + # [7., 8., 9.]]]]] + # And then gather along cols (dim=4) with blocks_row_indices = [[0,1], [1,2]] to get: + # [[[[[[1., 2.], + # [4., 5.]], + # [[2., 3.], + # [5., 6]]], + # [[[4., 5.], + # [7., 8.]], + # [[5., 6.], + # [8., 9.]]]]]] + # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get: + # [[[1., 2., 4., 5.], + # [2., 3., 5., 6.], + # [4., 5., 7., 8.], + # [5., 6., 8., 9.]]] + output = op.Gather(padded_input, blocks_row_indices, axis=2) + output = op.Gather(output, blocks_col_indices, axis=4) + output = op.Transpose(output, perm=[0, 1, 2, 4, 3, 5]) + return op.Reshape(output, output_shape) def aten_infinitely_differentiable_gelu_backward( @@ -1241,10 +1363,11 @@ def aten_multilabel_margin_loss_forward( raise NotImplementedError() -@torch_op("aten::nll_loss", traceable=True) +@torch_op("aten::nll_loss", 
trace_only=True) def aten_nll_loss( self: TFloat, target: INT64, + weight: Optional[TFloat] = None, reduction: int = 1, ignore_index: int = -100, ) -> TFloat: @@ -1259,55 +1382,15 @@ def aten_nll_loss( target = op.Unsqueeze(target, op.Constant(value_ints=[0])) if reduction == 0: - result = op.NegativeLogLikelihoodLoss( - self, target, ignore_index=ignore_index, reduction="none" - ) + reduction_str = "none" elif reduction == 1: - result = op.NegativeLogLikelihoodLoss( - self, target, ignore_index=ignore_index, reduction="mean" - ) + reduction_str = "mean" else: # assert reduction == 2 - result = op.NegativeLogLikelihoodLoss( - self, target, ignore_index=ignore_index, reduction="sum" - ) - - if self_rank_is_1: - result = op.Squeeze(result) - - return result - + reduction_str = "sum" -@torch_op("aten::nll_loss", traceable=True) -def aten_nll_loss_weight( - self: TFloat, - target: INT64, - weight: TFloat, - reduction: int = 1, - ignore_index: int = -100, -) -> TFloat: - """nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor""" - - self_rank_is_1 = Rank(self) == 1 - if self_rank_is_1: - # self rank should be at least 2 - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - - rank_target = Rank(target) - if rank_target == 0: # target rank should be at least 1 - target = op.Unsqueeze(target, op.Constant(value_ints=[0])) - - if reduction == 0: - result = op.NegativeLogLikelihoodLoss( - self, target, weight, ignore_index=ignore_index, reduction="none" - ) - elif reduction == 1: - result = op.NegativeLogLikelihoodLoss( - self, target, weight, ignore_index=ignore_index, reduction="mean" - ) - else: - result = op.NegativeLogLikelihoodLoss( - self, target, weight, ignore_index=ignore_index, reduction="sum" - ) + result = op.NegativeLogLikelihoodLoss( + self, target, weight, ignore_index=ignore_index, reduction=reduction_str + ) if self_rank_is_1: result = op.Squeeze(result) @@ -1367,16 +1450,23 @@ def aten_nll_loss_backward( raise NotImplementedError() +@torch_op("aten::nll_loss_forward", trace_only=True) def aten_nll_loss_forward( self: TensorType, target: TensorType, weight: Optional[TensorType], reduction: int, - ignore_index: INT64, + ignore_index: int, ) -> tuple[TensorType, TensorType]: """nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)""" - raise NotImplementedError() + output = aten_nll_loss(self, target, weight, reduction, ignore_index) + # FIXME: Fake a total_weight tensor for now. 
It should be different based on weight, reduction and ignore_index + if weight is None: + total_weight = op.CastLike(op.Size(output), self) + else: + total_weight = op.CastLike(op.Size(output), weight) + return output, total_weight def aten_nll_loss_nd( diff --git a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py index fa2df9751..92962a9ea 100644 --- a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py +++ b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py @@ -56,6 +56,8 @@ def quantized_decomposed_dequantize_per_tensor( ) -> TensorType: # TODO(justinchuby): Use dtype when we use opset 21 dequantized = op.DequantizeLinear(input, scale, common.constant(zero_point, dtype=dtype)) - if out_dtype == -1: + if out_dtype in (-1, None): + # out_dtype can be None as well return dequantized + assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}" return op.Cast(dequantized, to=out_dtype) diff --git a/onnxscript/ir/_convenience.py b/onnxscript/ir/_convenience.py index 86d2f88c3..166e7581b 100644 --- a/onnxscript/ir/_convenience.py +++ b/onnxscript/ir/_convenience.py @@ -395,3 +395,45 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]: continue values[value.name] = value return values + + +def replace_nodes_and_values( + graph_or_function: _core.Graph | _core.Function, + /, + insertion_point: _core.Node, + old_nodes: Sequence[_core.Node], + new_nodes: Sequence[_core.Node], + old_values: Sequence[_core.Value], + new_values: Sequence[_core.Value], +) -> None: + """Replaces nodes and values in the graph or function. + + Args: + graph_or_function: The graph or function to replace nodes and values in. + insertion_point: The node to insert the new nodes after. + old_nodes: The nodes to replace. + new_nodes: The nodes to replace with. + old_values: The values to replace. + new_values: The values to replace with. + """ + + for old_value, new_value in zip(old_values, new_values): + # Propagate relevant info from old value to new value + # TODO(Rama): Perhaps this should be a separate utility function. Also, consider + # merging old and new type/shape info. + new_value.type = old_value.type + new_value.shape = old_value.shape + new_value.const_value = old_value.const_value + new_value.name = old_value.name + + # Reconnect the users of the deleted values to use the new values + replace_all_uses_with(old_values, new_values) + # Update graph/function outputs if the node generates output + replacement_mapping = dict(zip(old_values, new_values)) + for idx, graph_or_function_output in enumerate(graph_or_function.outputs): + if graph_or_function_output in replacement_mapping: + graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] + + # insert new nodes after the index node + graph_or_function.insert_after(insertion_point, new_nodes) + graph_or_function.remove(old_nodes, safe=True) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 7eeba0493..b5a29cdd4 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1647,12 +1647,12 @@ def _check_node_safe_to_remove( raise ValueError( f"Node '{node!r}' is still an output of the graph and cannot be removed when safe=True." 
) - for use, _ in output.uses(): - if use in to_remove: - continue + uses_not_to_remove = [user for user, _ in output.uses() if user not in to_remove] + if uses_not_to_remove: raise ValueError( - f"Node '{use!r}' is still being used by other nodes that are not to be " - f"removed. All of its uses: {list(output.uses())!r}" + f"Output value '{output!r}' is still being used by other nodes that are not to be " + f"removed. All of its users that is not being removed: {uses_not_to_remove!r}. " + "Please make sure these nodes are no longer using the output value." ) diff --git a/onnxscript/rewriter/_tape.py b/onnxscript/ir/_tape.py similarity index 50% rename from onnxscript/rewriter/_tape.py rename to onnxscript/ir/_tape.py index 8ebed05fa..0a179af85 100644 --- a/onnxscript/rewriter/_tape.py +++ b/onnxscript/ir/_tape.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import Iterable, Mapping, Sequence +from typing import Any, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple from onnxscript import ir from onnxscript.ir import _convenience @@ -19,8 +19,8 @@ class Tape(Iterable[ir.Node]): def __init__(self) -> None: self._nodes: list[ir.Node] = [] - def __iter__(self) -> Sequence[ir.Node]: - return self._nodes + def __iter__(self) -> Iterator[ir.Node]: + return iter(self._nodes) @property def nodes(self) -> Sequence[ir.Node]: @@ -59,3 +59,46 @@ def op_multi_output( self._nodes.append(node) return node.outputs + + +# A type representing the domains/versions used in creating nodes in IR. +UsedOpsets = List[Tuple[str, Optional[int]]] + + +class Builder(Tape): + """An extension of the tape that provides a more convenient API for constructing the IR.""" + + def __init__(self): + super().__init__() + self._used_opsets: UsedOpsets = [] + + def __getattr__(self, op_type: str) -> Any: + return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) + + def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]): + domain = kwargs.pop("_domain", "") + version = kwargs.pop("_version", None) + outputs = kwargs.pop("_outputs", 1) + if isinstance(outputs, Sequence): + num_outputs = len(outputs) + else: + assert isinstance(outputs, int) + num_outputs = outputs + + self._used_opsets.append((domain, version)) + if num_outputs == 1: + value = super().op(op_type, inputs=inputs, attributes=kwargs, domain=domain) + if isinstance(outputs, Sequence): + value.name = outputs[0] + return value + values = super().op_multi_output( + op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs + ) + if isinstance(outputs, Sequence): + for value, name in zip(values, outputs): + value.name = name + return values + + @property + def used_opsets(self) -> UsedOpsets: + return self._used_opsets diff --git a/onnxscript/ir/convenience.py b/onnxscript/ir/convenience.py index 03140f16a..fc8416cc1 100644 --- a/onnxscript/ir/convenience.py +++ b/onnxscript/ir/convenience.py @@ -8,12 +8,14 @@ "convert_attribute", "convert_attributes", "replace_all_uses_with", + "replace_nodes_and_values", ] from onnxscript.ir._convenience import ( convert_attribute, convert_attributes, replace_all_uses_with, + replace_nodes_and_values, ) # NOTE: Do not implement any other functions in this module. 
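
For context on the new `ir/_tape.py` `Builder`: the underscore-prefixed keyword arguments are builder directives rather than ONNX attributes, which is why call sites throughout this diff change `domain=`/`outputs=` to `_domain=`/`_outputs=`. A self-contained sketch of just that kwarg-splitting convention (the helper name `split_builder_kwargs` is illustrative, not part of the API):

```python
from typing import Any

def split_builder_kwargs(kwargs: dict[str, Any]):
    """Peel off the builder-only kwargs the way Builder._make_node does."""
    domain = kwargs.pop("_domain", "")      # target op domain ("" = default ONNX domain)
    version = kwargs.pop("_version", None)  # optional opset version
    outputs = kwargs.pop("_outputs", 1)     # output count, or a list of output names
    return domain, version, outputs, kwargs  # whatever remains becomes node attributes

# e.g. a replacement like  op.Gelu(x, _domain="com.microsoft")
assert split_builder_kwargs({"_domain": "com.microsoft"}) == ("com.microsoft", None, 1, {})

# e.g.  op.Split(x, num_outputs=2, axis=-1, _outputs=2)  from llama_rule_sets.py
domain, version, outputs, attrs = split_builder_kwargs(
    {"num_outputs": 2, "axis": -1, "_outputs": 2}
)
assert (domain, outputs, attrs) == ("", 2, {"num_outputs": 2, "axis": -1})
```
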
diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 6140b06f7..a34b9810b 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -18,6 +18,7 @@ import onnxscript.ir._convenience as _convenience import onnxscript.optimizer.constant_folding as constant_folding import onnxscript.rewriter.pattern as orp +import onnxscript.utils.utils as utils def is_control_flow_op(node: ir.Node) -> bool: @@ -27,14 +28,13 @@ def is_control_flow_op(node: ir.Node) -> bool: def is_non_deterministic_op(node: ir.Node) -> bool: - return ( - node.op_type in constant_folding.non_deterministic_ops - and constant_folding.is_onnx_domain(node.domain) + return node.op_type in constant_folding.non_deterministic_ops and utils.is_onnx_domain( + node.domain ) def is_constant_op(node: ir.Node) -> bool: - return node.op_type in {"Constant", "ConstantOfShape"} and constant_folding.is_onnx_domain( + return node.op_type in {"Constant", "ConstantOfShape"} and utils.is_onnx_domain( node.domain ) @@ -362,7 +362,7 @@ def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValu unsqueezed_inputs = [] for node_input in inputs: unsqueezed_input = op.Unsqueeze( - node_input, axis_value, outputs=[f"{node_input.name}_unsqueeze"] + node_input, axis_value, _outputs=[f"{node_input.name}_unsqueeze"] ) unsqueezed_inputs.append(unsqueezed_input) # Send unsqueezed outputs to Concat @@ -427,13 +427,13 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: num_outputs = math.ceil(split_dimension_size / split_value.item()) split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] split_values = op.Split( - input, axis=axis, num_outputs=num_outputs, outputs=split_outputs + input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs ) elif split_value.ndim == 1: # split into 'size(split)' chunks num_outputs = split_value.size split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] - split_values = op.Split(input, split, axis=axis, outputs=split_outputs) + split_values = op.Split(input, split, axis=axis, _outputs=split_outputs) else: return None @@ -442,11 +442,11 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return None if keepdims == 0: # squeeze the split dimension if keepdims is 0 - axis_val = op.Constant(value_int=axis, outputs=[f"{output.name}_axis"]) + axis_val = op.Constant(value_int=axis, _outputs=[f"{output.name}_axis"]) squeezed_values = [] for i in range(num_outputs): squeezed = op.Squeeze( - split_values[i], axis_val, outputs=[f"{split_outputs[i]}_squeeze"] + split_values[i], axis_val, _outputs=[f"{split_outputs[i]}_squeeze"] ) squeezed_values.append(squeezed) split_values = squeezed_values @@ -648,32 +648,11 @@ def convert(av): def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function): logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name) - # TODO: what about new opset_imports? - old_values = node.outputs - new_values = replacement.new_outputs - for old_value, new_value in zip(old_values, new_values): - # Propagate relevant info from old value to new value - # TODO(Rama): Perhaps we should merge old and new types. As of now, new - # values don't have type information. Note that this could be a problem - # for semantics-altering rewrite-rules: we should allow users to override - # this for such rules. 
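
Relating to the `split_to_sequence` folding above: the number of `Split` outputs is decided from the constant `split` value before any nodes are built. A small sketch of that decision (hypothetical helper name, spot-checked with NumPy):

```python
import math
from typing import Optional

import numpy as np

def split_output_count(split_dimension_size: int, split: np.ndarray) -> Optional[int]:
    """Mirrors the branch that decides how many Split outputs to emit."""
    if split.ndim == 0:
        # Scalar chunk size: equally sized chunks, the last possibly smaller.
        return math.ceil(split_dimension_size / split.item())
    if split.ndim == 1:
        # Explicit per-chunk sizes: one output per entry.
        return int(split.size)
    return None  # any other shape: the SplitToSequence node is left unfolded

assert split_output_count(10, np.array(4)) == 3           # chunks of 4 + 4 + 2
assert split_output_count(10, np.array([2, 3, 5])) == 3
assert split_output_count(10, np.ones((2, 2))) is None
```
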
- new_value.type = old_value.type - new_value.shape = old_value.shape - new_value.const_value = old_value.const_value - new_value.name = old_value.name - - # Reconnect the users of the deleted node to use the new outputs - _convenience.replace_all_uses_with(old_values, new_values) - # Update graph/function outputs if the node generates output - replacement_mapping = dict(zip(old_values, new_values)) - for idx, graph_or_function_output in enumerate(root.outputs): - if graph_or_function_output in replacement_mapping: - root.outputs[idx] = replacement_mapping[graph_or_function_output] - - # insert new nodes after the index node - root.insert_after(node, replacement.new_nodes) - root.remove(node, safe=True) + _convenience.replace_nodes_and_values( + root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs + ) + # TODO: what about new opset_imports? # TODO: track statistics about replaced nodes and sizes of new constants def visit_attribute(self, attr: ir.Attr | ir.RefAttr) -> None: @@ -698,12 +677,17 @@ def visit_graph(self, graph: ir.Graph) -> None: for node in graph: self.visit_node(node, graph) + def visit_function(self, function: ir.Function) -> None: + for node in function: + self.visit_node(node, function) + def visit_model(self, model: ir.Model) -> None: self._init() self.opset_imports = model.opset_imports self.visit_graph(model.graph) - # TODO(rama): handle functions - # Pending decision on whether we want to specialize functions or not. + for function in model.functions.values(): + # TODO(rama): Should we specialize functions? + self.visit_function(function) def fold_constants( diff --git a/onnxscript/rewriter/cast_constant_of_shape.py b/onnxscript/rewriter/cast_constant_of_shape.py index bd58af933..34656ff19 100644 --- a/onnxscript/rewriter/cast_constant_of_shape.py +++ b/onnxscript/rewriter/cast_constant_of_shape.py @@ -9,7 +9,6 @@ from onnxscript import ir from onnxscript.rewriter import pattern -op = pattern.onnxop logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/erfgelu.py b/onnxscript/rewriter/erfgelu.py index ea8d27a4e..c821a79b3 100644 --- a/onnxscript/rewriter/erfgelu.py +++ b/onnxscript/rewriter/erfgelu.py @@ -21,7 +21,7 @@ def erf_gelu_pattern(op, x): # Replacement def gelu(op, x): - return op.Gelu(x, domain="com.microsoft") + return op.Gelu(x, _domain="com.microsoft") rule = pattern.RewriteRule(erf_gelu_pattern, gelu) diff --git a/onnxscript/rewriter/gemm_to_matmul_add.py b/onnxscript/rewriter/gemm_to_matmul_add.py index 0b9ee373b..bff77839f 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add.py +++ b/onnxscript/rewriter/gemm_to_matmul_add.py @@ -3,8 +3,6 @@ from onnxscript.rewriter import pattern from onnxscript.rewriter.broadcast_to_matmul import check_if_not_need_reshape -op = pattern.onnxop - # Pattern to match against def reshape_gemm_reshape_pattern(op, input_a, input_b, input_c, shape_a, shape_c): diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py index d0daf2e06..2926f5964 100644 --- a/onnxscript/rewriter/generic_pattern.py +++ b/onnxscript/rewriter/generic_pattern.py @@ -298,7 +298,7 @@ def _match_backward( return self.none(starting_node, inspect.currentframe().f_lineno) for graph_input, pattern_input in zip(graph_node.inputs, pattern_node.inputs): - if len(list(graph_input.uses())) != len(list(pattern_input.uses())): + if len(graph_input.uses()) != len(pattern_input.uses()): self._hint( "BACKWARD: one input is used outside the pattern", "-- pattern", @@ -423,12 +423,12 @@ def 
_match_values_forward( return match_count if len(free) < len(pattern_node_users_not_matched): # Not enough successors to match the remaining patterns. - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) if len(pattern_node_users_not_matched) == len(free) == 1: # Only one option again. graph_node = free[0] if pattern_node_users_not_matched[0].op_identifier() != graph_node.op_identifier(): - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) key = pattern_node_users_not_matched[0] if self.verbose >= 10: @@ -461,11 +461,11 @@ def _match_values_forward( "-- model-matched", pattern_node_users_matched, ) - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) for k, v in ec.items(): if gc[k] < v: # Not enough types to match. - return self.none(node, inspect.currentframe().f_lineno) + return self.none(starting_node, inspect.currentframe().f_lineno) # At this stage, we know matching the types is possible. # We first mark whatever is possible. diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py index d65f01c8d..dadaf5e8b 100644 --- a/onnxscript/rewriter/generic_pattern_test.py +++ b/onnxscript/rewriter/generic_pattern_test.py @@ -12,6 +12,7 @@ import onnx.parser import onnx.reference import onnxruntime as ort +import parameterized from onnxscript import ir from onnxscript.rewriter import generic_pattern, pattern @@ -19,6 +20,13 @@ FLOAT = onnx.TensorProto.FLOAT +@parameterized.parameterized_class( + ("matcher_algo",), + [ + (generic_pattern.GenericPatternMatcher,), + (pattern.SimplePatternMatcher,), + ], +) class GenericPatternTest(unittest.TestCase): def _range(self, *shape, bias: float | None = None): n = np.prod(shape) @@ -37,7 +45,7 @@ def match_pattern(op, x, y, z): def apply_pattern(op, x, y, z, **_): """Builds the replacement graph.""" - return op.AddAdd(x, y, z, domain="ZZZ") + return op.AddAdd(x, y, z, _domain="ZZZ") def validate_mapping(context, x, y, z, **_) -> bool: """Validates the mapping.""" @@ -48,7 +56,7 @@ def validate_mapping(context, x, y, z, **_) -> bool: match_pattern, apply_pattern, validate_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, ) class AddAdd(onnx.reference.op_run.OpRun): @@ -119,7 +127,7 @@ def match_pattern(op, x, y, w, z): def apply_pattern(op, x, y, w, z, **_): """Builds the pattern to match.""" - return op.AddAddAddAdd(x, y, w, z, domain="ZZZ", outputs=2) + return op.AddAddAddAdd(x, y, w, z, _domain="ZZZ", _outputs=2) def validate_mapping(context, **_) -> bool: return True @@ -128,7 +136,7 @@ def validate_mapping(context, **_) -> bool: match_pattern, apply_pattern, validate_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=10, ) @@ -254,13 +262,9 @@ def match_pattern(op, x): return t1, t2 def apply_pattern(op, x, **_): - return op.SinCos(x, domain="com.microsoft", outputs=2) + return op.SinCos(x, _domain="com.microsoft", _outputs=2) - rule = pattern.RewriteRule( - match_pattern, - apply_pattern, - matcher=generic_pattern.GenericPatternMatcher, - ) + rule = pattern.RewriteRule(match_pattern, apply_pattern, matcher=self.matcher_algo) model_proto = onnx.parser.parse_model( """ @@ -281,20 +285,22 @@ def apply_pattern(op, x, **_): self.assertEqual(len(graph.node), 2) self.assertEqual(graph.node[0].op_type, "SinCos") - @unittest.skip("Input variable reuse 
not supported yet") def test_shared_root_value_extra_use(self): + if self.matcher_algo is generic_pattern.GenericPatternMatcher: + raise unittest.SkipTest("GenericPatternMatcher does not support extra uses yet.") + def match_pattern(op, x): t1 = op.Sin(x) t2 = op.Cos(x) return t1, t2 def apply_pattern(op, x, **_): - return op.SinCos(x, domain="com.microsoft", outputs=2) + return op.SinCos(x, _domain="com.microsoft", _outputs=2) rule = pattern.RewriteRule( match_pattern, apply_pattern, - matcher=generic_pattern.GenericPatternMatcher, + matcher=self.matcher_algo, ) model_proto = onnx.parser.parse_model( """ @@ -314,7 +320,7 @@ def apply_pattern(op, x, **_): rule.apply_to_model(ir_model) graph = ir_model.graph self.assertEqual(len(graph), 3) - self.assertEqual(graph.node[0].op_type, "SinCos") + self.assertEqual(graph.node(0).op_type, "SinCos") def test_rotary_embedding(self): # The test work on a model if it has the expected name. @@ -332,8 +338,8 @@ def match_pattern(op, x, pos_ids, axis): output, _length = op.ConcatTraining( transpose, transpose, - domain="com.microsoft", - outputs=2, + _domain="com.microsoft", + _outputs=2, ) sin = op.Sin(output) @@ -359,15 +365,15 @@ def apply_pattern(op, x, pos_ids, axis, **_): pos_ids, cos_cache, sin_cache, - domain="com.microsoft", - outputs=2, + _domain="com.microsoft", + _outputs=2, ) rule = pattern.RewriteRule( match_pattern, apply_pattern, validate_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=10, ) @@ -389,7 +395,8 @@ def apply_pattern(op, x, pos_ids, axis, **_): self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) out = buffer.getvalue() # TODO(Rama): What is this assertion testing? Is it to check that `verbose` is working? - self.assertIn("[GenericPatternMatcher.match", out) + if self.matcher_algo is generic_pattern.GenericPatternMatcher: + self.assertIn("[GenericPatternMatcher.match", out) def test_rotary_embedding_onnxscript(self): # The test work on a model if it has the expected name. @@ -402,7 +409,7 @@ def rotary_match_pattern(op, x, pos_ids, axis): matmul = op.MatMul(pos_ids, cast) transpose = op.Transpose(matmul) output, _length = op.ConcatTraining( - transpose, transpose, domain="com.microsoft", outputs=2 + transpose, transpose, _domain="com.microsoft", _outputs=2 ) sin = op.Sin(output) @@ -424,7 +431,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_): value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) ) part1, part2 = op.RotaryEmbedding( - x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2 + x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2 ) return part1, part2 @@ -432,7 +439,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_): rotary_match_pattern, rotary_apply_pattern, validate_rotary_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=10, ) @@ -454,7 +461,8 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_): self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) out = buffer.getvalue() # TODO(justinchuby): Remove this assert - capturing stdout is not robust - self.assertIn("[GenericPatternMatcher.match", out) + if self.matcher_algo is generic_pattern.GenericPatternMatcher: + self.assertIn("[GenericPatternMatcher.match", out) def test_rotary_emb_file_onnxscript(self): # The test work on a model if it has the expected name. 
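# --- Editor's note: illustrative sketch, not part of the patch ---------------------
# The hunks above rename the pattern-builder keyword arguments from `domain=` /
# `outputs=` to the underscore-prefixed `_domain=` / `_outputs=`, and run each test
# under both matcher implementations via `parameterized.parameterized_class`.
# The snippet below shows a minimal end-to-end use of the renamed keywords with a
# multi-output pattern. It assumes the APIs exercised in this test file
# (`pattern.RewriteRule`, `rule.apply_to_model`) plus `ir.serde.deserialize_model`
# for building the IR model; `SinCos` is only an illustrative fused op.

import onnx.parser

from onnxscript import ir
from onnxscript.rewriter import pattern


def sin_cos_pattern(op, x):
    # Two pattern outputs, produced by two different nodes.
    return op.Sin(x), op.Cos(x)


def sin_cos_replacement(op, x, **_):
    # `_domain` selects the opset of the replacement op; `_outputs=2` declares that
    # the replacement produces two outputs.
    return op.SinCos(x, _domain="com.microsoft", _outputs=2)


rule = pattern.RewriteRule(
    sin_cos_pattern, sin_cos_replacement, matcher=pattern.SimplePatternMatcher
)

model_proto = onnx.parser.parse_model(
    """
    <ir_version: 8, opset_import: [ "" : 18 ]>
    agraph (float[N] x) => (float[N] s, float[N] c)
    {
        s = Sin(x)
        c = Cos(x)
    }
    """
)
ir_model = ir.serde.deserialize_model(model_proto)
rule.apply_to_model(ir_model)
# ------------------------------------------------------------------------------------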
@@ -467,7 +475,7 @@ def rotary_match_pattern(op, x, pos_ids, axis): matmul = op.MatMul(pos_ids, cast) transpose = op.Transpose(matmul) output, _length = op.ConcatTraining( - transpose, transpose, domain="com.microsoft", outputs=2 + transpose, transpose, _domain="com.microsoft", _outputs=2 ) sin = op.Sin(output) @@ -489,7 +497,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis): value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) ) part1, part2 = op.RotaryEmbedding( - x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2 + x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2 ) return part1, part2 @@ -504,7 +512,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis): rotary_match_pattern, rotary_apply_pattern, validate_rotary_mapping, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=10, ) @@ -527,8 +535,8 @@ def test_transpose_transpose_onnxscript(self): # return Y def transpose_transpose_pattern(op, X): - XT = op.Transpose(X, outputs=["XT"]) - Y = op.Transpose(XT, outputs=["Y"]) + XT = op.Transpose(X, _outputs=["XT"]) + Y = op.Transpose(XT, _outputs=["Y"]) return Y def transpose_transpose_mapping(perm0, perm1): @@ -561,7 +569,7 @@ def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_): transpose_transpose_pattern, transpose_transpose_apply_pattern, transpose_transpose_check, - generic_pattern.GenericPatternMatcher, + self.matcher_algo, verbose=0, ) diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py index 6be58dd65..0d163d0a2 100644 --- a/onnxscript/rewriter/llama_rule_sets.py +++ b/onnxscript/rewriter/llama_rule_sets.py @@ -11,8 +11,6 @@ import onnxscript.rewriter.no_op as no_op import onnxscript.rewriter.pattern as orp -op = orp.onnxop - class CastIdentity(orp.RewriteRuleAsClass): """Replaces ``Cast(., to=to)`` by ``Identity`` if possible.""" @@ -155,7 +153,7 @@ def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> bool: @classmethod def rewrite(cls, op, x, begin0, end0, axes0, begin1, end1, axes1): - return op.Split(x, num_outputs=2, axis=-1, outputs=2) + return op.Split(x, num_outputs=2, axis=-1, _outputs=2) class TransposeIdentity(orp.RewriteRuleAsClass): diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/no_op.py index 95c3e2434..7a4b00798 100644 --- a/onnxscript/rewriter/no_op.py +++ b/onnxscript/rewriter/no_op.py @@ -2,8 +2,6 @@ # Licensed under the MIT License. 
from onnxscript.rewriter import pattern -op = pattern.onnxop - # TODO: Support 1-D constant tensors # https://github.com/microsoft/onnx-rewriter/issues/186 diff --git a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py index 83f263304..65496ec8b 100644 --- a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py @@ -6,8 +6,6 @@ import onnxscript.rewriter.pattern as orp -op = orp.onnxop - class FusedMatMulDiv1(orp.RewriteRuleAsClass): """Replaces ``MatMul + Div`` by FusedMatMul.""" @@ -29,7 +27,7 @@ def check(cls, context, x, y, cst) -> bool: def rewrite(cls, op, x, y, cst): value = cst.const_value.numpy() c = float(value[0] if value.shape == (1,) else value) - return op.FusedMatMul(x, y, alpha=1 / c, domain="com.microsoft") + return op.FusedMatMul(x, y, alpha=1 / c, _domain="com.microsoft") class FusedMatMulDiv2(orp.RewriteRuleAsClass): @@ -37,7 +35,7 @@ class FusedMatMulDiv2(orp.RewriteRuleAsClass): @classmethod def pattern(cls, op, x, y, cst): - return op.Div(op.FusedMatMul(x, y, domain="com.microsoft"), cst) + return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst) @classmethod def check(cls, context, x, y, cst) -> bool: @@ -60,7 +58,7 @@ def rewrite(cls, op, x, y, cst): att = node.attributes.get(name) if att: kwargs[name] = att.value - return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft") + return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft") class _TransposeMatMulBase(orp.RewriteRuleAsClass): @@ -83,7 +81,7 @@ def rewrite(cls, op, x, y): kwargs[name] = att.value name = "transA" if cls._pos == 1 else "transB" kwargs[name] = 1 - kwargs.get(name, 0) - return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft") + return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft") class TransposeMatMul1(_TransposeMatMulBase): @@ -99,7 +97,7 @@ class TransposeFusedMatMul1(TransposeMatMul1): @classmethod def pattern(cls, op, x, y): - return op.FusedMatMul(op.Transpose(x), y, domain="com.microsoft") + return op.FusedMatMul(op.Transpose(x), y, _domain="com.microsoft") class TransposeMatMul2(_TransposeMatMulBase): @@ -117,7 +115,7 @@ class TransposeFusedMatMul2(TransposeMatMul2): @classmethod def pattern(cls, op, x, y): - return op.FusedMatMul(x, op.Transpose(y), domain="com.microsoft") + return op.FusedMatMul(x, op.Transpose(y), _domain="com.microsoft") class MatMulTranspose(orp.RewriteRuleAsClass): @@ -146,7 +144,7 @@ def rewrite(cls, op, x, y): kwargs[name] = att.value for name in ["transA", "transB"]: kwargs[name] = 1 - kwargs.get(name, 0) - return op.FusedMatMul(y, x, **kwargs, domain="com.microsoft") + return op.FusedMatMul(y, x, **kwargs, _domain="com.microsoft") class FusedMatMulTranspose(MatMulTranspose): @@ -154,7 +152,7 @@ class FusedMatMulTranspose(MatMulTranspose): @classmethod def pattern(cls, op, x, y): - return op.Transpose(op.FusedMatMul(x, y, domain="com.microsoft")) + return op.Transpose(op.FusedMatMul(x, y, _domain="com.microsoft")) def fused_matmul_rule_sets() -> orp.RewriteRuleSet: diff --git a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py index 843ad920b..7372ef6cf 100644 --- a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py +++ b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py @@ -27,7 +27,7 @@ def group_normalization_and_silu_submodule( channels_last=1, epsilon=epsilon, groups=groups, - 
domain="com.microsoft", + _domain="com.microsoft", ) transposed = op.Transpose(group_norm, perm=[0, 3, 1, 2]) return torch_module_op.submodule("torch_nn_modules_activation_SiLU")( @@ -51,7 +51,7 @@ def group_normalization_with_silu( channels_last=1, epsilon=epsilon, groups=groups, - domain="com.microsoft", + _domain="com.microsoft", ) return op.Transpose(group_norm, perm=[0, 3, 1, 2]) diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py index 8d5da54bb..46158550c 100644 --- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py +++ b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py @@ -143,7 +143,7 @@ def group_normalization(op, input_x, weight_for_norm, weight_full, bias_full, ep channels_last=1, epsilon=epsilon, groups=groups, - domain="com.microsoft", + _domain="com.microsoft", ) return op.Transpose(output, perm=[0, 3, 1, 2]) diff --git a/onnxscript/rewriter/onnxruntime/softmax.py b/onnxscript/rewriter/onnxruntime/softmax.py index 12ad97672..f1d6df7b6 100644 --- a/onnxscript/rewriter/onnxruntime/softmax.py +++ b/onnxscript/rewriter/onnxruntime/softmax.py @@ -9,7 +9,6 @@ from onnxscript import ir from onnxscript.rewriter import pattern -op = pattern.onnxop logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 806ebc09e..b7f86dfce 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc +import contextlib import dataclasses import inspect import itertools @@ -10,10 +11,9 @@ from typing import ( Any, Callable, + Iterable, Iterator, - List, MutableSequence, - Optional, Protocol, Sequence, Tuple, @@ -22,8 +22,8 @@ ) from onnxscript import ir -from onnxscript.ir import _convenience -from onnxscript.rewriter import _ir_utils, _tape +from onnxscript.ir import _convenience, _tape +from onnxscript.rewriter import _ir_utils T = TypeVar("T") @@ -34,7 +34,19 @@ class Pattern(Protocol[T]): # type: ignore[misc] def matches(self, item: T) -> bool: ... -class StringConstantPattern(Pattern[str]): +class StringPattern(abc.ABC, Pattern[str]): + """Abstract base class for string patterns.""" + + @abc.abstractmethod + def matches(self, item: str) -> bool: + pass + + @abc.abstractmethod + def __str__(self) -> str: + pass + + +class StringConstantPattern(StringPattern): """Matches strings with given value.""" def __init__(self, value: str): @@ -46,8 +58,11 @@ def matches(self, item: str) -> bool: def __str__(self) -> str: return self._value + def value(self) -> str: + return self._value + -class PrefixPattern(Pattern[str]): +class PrefixPattern(StringPattern): """Matches strings with a given prefix.""" def __init__(self, value: str) -> None: @@ -110,7 +125,7 @@ def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) -> """Represents promotion of values allowed as keyword-arguments in a pattern-builder call to an AttrPattern.""" if isinstance(value, AttrPattern): return value - if type(value) == ValuePattern: + if type(value) is ValuePattern: # This is a hack. Currently, when we create pattern-variables, we create them as ValuePattern, # and change them to AttrPattern if/when used in an attribute context. 
We could use type # annotations to distinguish between ValuePattern and AttrPattern, but forces users to @@ -128,8 +143,8 @@ def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) -> raise TypeError(f"Cannot convert {type(value)} to AttrPattern") -class OpsetPatternBuilder(Pattern[str]): - """Represents an opset pattern. +class OpsetPatternBuilder: + """Represents an opset pattern and a pattern builder. (i) It is used to create a NodePattern (via OpPatternBuilder). Example usage: @@ -140,24 +155,21 @@ class OpsetPatternBuilder(Pattern[str]): Here, `op` is an instance of OpsetPatternBuilder and `op.Matmul` is an instance of OpPatternBuilder, and `op.Matmul(x, y)` is an instance of NodePattern. - (ii) An opset pattern is also matched against the actual opset domain used in the + (ii) It contains a domain pattern matched against the actual opset domain used in the input model. """ - def __init__(self, domain: Pattern[str] | str) -> None: + def __init__(self, domain: StringPattern | str, record: bool = False) -> None: if isinstance(domain, str): - self._domain_name: str | None = domain - self._domain_pattern: Pattern[str] = StringConstantPattern(domain) + domain = StringConstantPattern(domain) + self._domain_pattern = domain + if record: + self._nodes: list[NodePattern] | None = [] else: - self._domain_name = None - self._domain_pattern = domain + self._nodes = None - @property - def domain_name(self) -> str | None: - return self._domain_name - - def matches(self, domain): - return self._domain_pattern.matches(domain) + def domain_pattern(self) -> StringPattern: + return self._domain_pattern def __getattr__(self, op_name: str) -> OpPatternBuilder: return OpPatternBuilder(self, op_name) @@ -169,10 +181,17 @@ def submodule(self, name: str) -> OpPatternBuilder: def __str__(self) -> str: return str(self._domain_pattern) + def add_node(self, node: NodePattern) -> None: + if self._nodes is not None: + self._nodes.append(node) + + def nodes(self) -> Sequence[NodePattern]: + if self._nodes is None: + raise ValueError("Nodes were not recorded.") + return self._nodes -onnxop = OpsetPatternBuilder("") -msft_op = OpsetPatternBuilder("com.microsoft") +onnxop = OpsetPatternBuilder("") torch_module_op = OpsetPatternBuilder(PrefixPattern("pkg.torch")) @@ -193,45 +212,46 @@ class OpPatternBuilder: def __init__( self, - opset_pattern: OpsetPatternBuilder, + pattern_builder: OpsetPatternBuilder, op_name: str | Pattern[str], ) -> None: - self.opset_pattern = opset_pattern + self.pattern_builder = pattern_builder self.op_name = op_name def __call__( self, *args, - domain: str | None = None, - version: int | None = None, - outputs: int | list[str | None] = 1, + _domain: str | None = None, + _version: int | None = None, + _outputs: int | list[str | None] = 1, _allow_other_attributes: bool | None = None, **kwargs, ): - if version is not None: + if _version is not None: raise ValueError( - "The pattern builder does not support 'version' keyword argument. " + "The pattern builder does not support '_version' keyword argument. " "Version restrictions should be handled by rewrite rules." ) - if domain is None: - opset_pattern = self.opset_pattern - elif isinstance(domain, str): - opset_pattern = OpsetPatternBuilder(domain) + if _domain is None: + opset_pattern = self.pattern_builder.domain_pattern() + elif isinstance(_domain, str): + opset_pattern = StringConstantPattern(_domain) else: - # TODO(rama): allow OpsetPatternBuilder as domain. 
- raise TypeError("domain must be a string.") + # TODO(rama): allow OpsetPatternBuilder as _domain. + raise TypeError("_domain must be a string.") - if isinstance(outputs, int): - outputs = [None for _ in range(outputs)] - elif not isinstance(outputs, Sequence) or not all( - isinstance(x, (str, type(None))) for x in outputs + if isinstance(_outputs, int): + _outputs = [None for _ in range(_outputs)] + elif not isinstance(_outputs, Sequence) or not all( + isinstance(x, (str, type(None))) for x in _outputs ): - raise ValueError("outputs must be an int or a list[str|None].") + raise ValueError("_outputs must be an int or a list[str|None].") inputs = [_to_value_pattern(x) for x in args] attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()} node_pattern = NodePattern( - opset_pattern, self.op_name, inputs, attributes, outputs, _allow_other_attributes + opset_pattern, self.op_name, inputs, attributes, _outputs, _allow_other_attributes ) + self.pattern_builder.add_node(node_pattern) output_values = node_pattern.outputs # Unpack outputs if there is only one output, the common case. if len(output_values) == 1: @@ -353,6 +373,18 @@ def extend(self, other: MatchResult | bool): self._matched_nodes.extend(other._matched_nodes) # type: ignore[attr-defined] +_pattern_builder: OpsetPatternBuilder = onnxop + + +@contextlib.contextmanager +def pattern_builder(builder: OpsetPatternBuilder): + global _pattern_builder + prev_builder = _pattern_builder + _pattern_builder = builder + yield + _pattern_builder = prev_builder + + class ValuePattern: """Base class for all patterns that match against IR values. @@ -365,6 +397,10 @@ def __init__(self, name: str | None) -> None: # Note: uses will be computed only when the full graph-pattern is constructed. self._uses: list[tuple[NodePattern, int]] = [] + def clone(self, node_map: dict[NodePattern, NodePattern]) -> ValuePattern: + del node_map + return ValuePattern(self._name) + @property def name(self) -> str | None: return self._name @@ -381,41 +417,32 @@ def append_use(self, node: NodePattern, index: int): def __repr__(self) -> str: return f"ValuePattern({self._name!r})" - def commute(self) -> Sequence[ValuePattern]: - """Return a list of commuted patterns. - - This is used to handle commutative operations like addition and multiplication. - A single pattern is converted into a list of equivalent patterns by swapping - the parameters of commutative operations. 
- """ - return [self] - def __add__(self, other): - return onnxop.Add(self, other) + return _pattern_builder.Add(self, other) def __radd__(self, other): - return onnxop.Add(other, self) + return _pattern_builder.Add(other, self) def __sub__(self, other): - return onnxop.Sub(self, other) + return _pattern_builder.Sub(self, other) def __rsub__(self, other): - return onnxop.Sub(other, self) + return _pattern_builder.Sub(other, self) def __mul__(self, other): - return onnxop.Mul(self, other) + return _pattern_builder.Mul(self, other) def __rmul__(self, other): - return onnxop.Mul(other, self) + return _pattern_builder.Mul(other, self) def __truediv__(self, other): - return onnxop.Div(self, other) + return _pattern_builder.Div(self, other) def __rtruediv__(self, other): - return onnxop.Div(other, self) + return _pattern_builder.Div(other, self) def __pow__(self, other): - return onnxop.Pow(self, other) + return _pattern_builder.Pow(self, other) def __str__(self) -> str: return self._name if self._name is not None else "anonymous:" + str(id(self)) @@ -440,7 +467,7 @@ class NodePattern: def __init__( self, - domain: OpsetPatternBuilder, + domain: StringPattern, op: str | Pattern[str], inputs: Sequence[int | float | ValuePattern | None], attributes: dict[str, AttrPattern], @@ -456,11 +483,11 @@ def __init__( self.attributes = attributes self.allow_other_attributes = allow_other_attributes # In the common case, domain and op are constants, which can be used to optimize matching. - if isinstance(op, str) and domain.domain_name is not None: + if isinstance(op, str) and isinstance(domain, StringConstantPattern): # TODO(rama): support overloaded operators. overload = "" self._op_identifier: tuple[str, str, str] | None = ( - domain.domain_name, + domain.value(), op, overload, ) @@ -522,36 +549,19 @@ def matches(self, node: ir.Node, match: MatchResult) -> MatchResult: return match - def commute(self) -> Sequence[NodePattern]: - list_of_lists = [ - [None] if pattern is None else pattern.commute() for pattern in self.inputs - ] # type: ignore[attr-defined] - - def enumerate_inputs(inputs, index): - if index >= len(inputs): - yield [] - else: - for pattern in inputs[index]: - for rest in enumerate_inputs(inputs, index + 1): - yield [pattern, *rest] - - inputs = list(enumerate_inputs(list_of_lists, 0)) - if self.domain.matches("") and (self.op.matches("Add") or self.op.matches("Mul")): - # TODO: handle cases where number of inputs is not 2. - swapped = [[x[1], x[0]] for x in inputs] - inputs.extend(swapped) + def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePattern: + inputs = [(v.clone(node_map) if v is not None else None) for v in self.inputs] + if swap: + assert ( + len(inputs) == 2 + ), "Internal error: commutative swap applies only to binary ops." 
+ inputs = [inputs[1], inputs[0]] outputs = [value.name for value in self.outputs] - return [ - NodePattern( - self.domain, - self.op, - input, - self.attributes, - outputs, - self.allow_other_attributes, - ) - for input in inputs - ] + copied = NodePattern( + self.domain, self.op, inputs, self.attributes, outputs, self.allow_other_attributes + ) + node_map[self] = copied + return copied class NodeOutputPattern(ValuePattern): @@ -568,17 +578,14 @@ def __init__( self._producer = producer self._output_index = output_index + def clone(self, node_map: dict[NodePattern, NodePattern]) -> NodeOutputPattern: + return node_map[self._producer].outputs[self._output_index] + # return NodeOutputPattern(node_map[self._producer], self._output_index, self._name) + @property def output_index(self) -> int: return self._output_index - def commute(self) -> Sequence[ValuePattern]: - # TODO - return [ - NodeOutputPattern(pattern, self._output_index, self.name) - for pattern in self._producer.commute() - ] - def producer(self) -> NodePattern: return self._producer @@ -597,6 +604,10 @@ def __init__( self._rel_tol = rel_tol self._abs_tol = abs_tol + def clone(self, node_map: dict[NodePattern, NodePattern]) -> Constant: + del node_map + return Constant(self._value, self._rel_tol, self._abs_tol) + @property def value(self) -> int | float: return self._value @@ -628,9 +639,6 @@ def matches(self, value: ir.Value, match: MatchResult) -> MatchResult: # used elsewhere. return match - def commute(self) -> list[ValuePattern]: - return [self] - def __str__(self) -> str: return str(self._value) @@ -656,30 +664,37 @@ class GraphPattern: """Represents a pattern that can be matched against a subgraph.""" def __init__( - self, inputs: Sequence[ValuePattern], outputs: Sequence[ValuePattern] + self, + inputs: Sequence[ValuePattern], + outputs: Sequence[ValuePattern], + nodes: Sequence[NodePattern], ) -> None: self._inputs = inputs self._outputs = outputs if len(outputs) == 0: raise ValueError("GraphPattern must have at least one output") - self._nodes = _nodes_in_pattern(outputs) + self._nodes = nodes # _nodes_in_pattern(outputs) # Check if all outputs are produced by the same node. - output_node = None - for i, value_pattern in enumerate(outputs): + output_nodes: set[NodePattern] = set() + for value_pattern in outputs: if not isinstance(value_pattern, ValuePattern): raise TypeError( f"Invalid type {type(value_pattern)} for graph pattern output." ) - if not isinstance(value_pattern, NodeOutputPattern) or ( - value_pattern.output_index != i - ): - output_node = None - elif i == 0: - output_node = value_pattern.producer() - elif value_pattern.producer() is not output_node: - output_node = None - self._output_node = output_node + if isinstance(value_pattern, Constant): + raise NotImplementedError( + "Constant values are not allowed as graph pattern outputs." 
+ ) + if isinstance(value_pattern, NodeOutputPattern): + output_nodes.add(value_pattern.producer()) + self.output_nodes: list[NodePattern] = list(output_nodes) + + @property + def output_node(self) -> NodePattern: + if len(self.output_nodes) != 1: + raise ValueError("GraphPattern does not have unique output node.") + return self.output_nodes[0] def node(self, index: int) -> NodePattern: return self._nodes[index] @@ -706,24 +721,40 @@ def __reversed__(self) -> Iterator[NodePattern]: @property def has_single_output_node(self) -> bool: - return self._output_node is not None + return len(self.output_nodes) == 1 @property def num_outputs(self) -> int: return len(self._outputs) def commute(self) -> Sequence[GraphPattern]: - if self._output_node is None: - raise NotImplementedError( - "Cannot commute a graph pattern with multiple output nodes." - ) - nodes = self._output_node.commute() - return [ - GraphPattern( - self._inputs, [NodeOutputPattern(n, i) for i in range(self.num_outputs)] - ) - for n in nodes - ] + def commute_node(node: NodePattern) -> Iterable[bool]: + if node.op_identifier() == ("", "Add", "") or node.op_identifier() == ( + "", + "Mul", + "", + ): + # Try with and without swapping inputs. + return [False, True] + # No swapping of inputs + return [False] + + iteration_space = [commute_node(node) for node in self._nodes] + + def copy_graph(swap_list: Iterable[bool]) -> GraphPattern: + if not any(swap_list): + # No need to swap inputs of any node + return self + # Create a copy of the graph, with swapped inputs for the nodes that need it. + node_map: dict[NodePattern, NodePattern] = {} + new_inputs = [v.clone(node_map) for v in self._inputs] + new_nodes = [ + node.clone(node_map, swap) for node, swap in zip(self._nodes, swap_list) + ] + new_outputs = [v.clone(node_map) for v in self._outputs] + return GraphPattern(new_inputs, new_outputs, new_nodes) + + return [copy_graph(swap_list) for swap_list in itertools.product(*iteration_space)] def __str__(self) -> str: inputs = ", ".join(str(v) for v in self._inputs) @@ -753,24 +784,29 @@ def pattern(op, x: Var, shape1: Var, shape2: Var): """ _pattern_vars = inspect.signature(pattern_constructor).parameters pattern_inputs = [Var(v) for v in _pattern_vars][1:] # Skip the first parameter - pattern_outputs = pattern_constructor(onnxop, *pattern_inputs) + builder = OpsetPatternBuilder("", record=True) + with pattern_builder(builder): + pattern_outputs = pattern_constructor(builder, *pattern_inputs) # TODO(rama): classify inputs as value/attribute vars # Returned value could be a single ValuePattern or a list of ValuePatterns. # Normalize representation to a list of ValuePatterns. if isinstance(pattern_outputs, ValuePattern): pattern_outputs = [pattern_outputs] - return GraphPattern(pattern_inputs, pattern_outputs) + return GraphPattern(pattern_inputs, pattern_outputs, builder.nodes()) -def _valid_to_replace(matched_nodes: Sequence[ir.Node]) -> bool: - """Check that values computed by the matched_nodes, except for the last one, are used only by the matched_nodes.""" +def _valid_to_replace( + matched_nodes: Sequence[ir.Node], output_values: Sequence[ir.Value] +) -> bool: + """Check that values computed by the matched_nodes, except for output_values, are used only by the matched_nodes.""" # * Must check that all values matched by pattern are used only by pattern, # except for the value that is replaced. # * Must ensure that replacement subgraph does not use any of the deleted # (intermediate) values. (Not necessary for now. Guaranteed.) 
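# --- Editor's note: illustrative sketch, not part of the patch ---------------------
# With the changes above, `GraphPattern.commute()` no longer rewrites patterns by
# recursively commuting each output node. Instead it records, for every node in the
# pattern, whether its inputs may be swapped (currently only Add and Mul in the
# default domain), takes the cartesian product of those per-node choices, and clones
# the whole pattern graph once per combination via the new `clone(node_map, swap)`
# methods. From the user's side, commutative matching is still requested through
# `RewriteRuleSet(..., commute=True)`. A minimal sketch, assuming the public classes
# defined in this file; `FusedMulAdd` is a hypothetical op used only for illustration.

from onnxscript.rewriter import pattern


def mul_add_pattern(op, x, y, z):
    # Matches x*y + z. The commuted clones also cover y*x + z, z + x*y, and z + y*x.
    return op.Add(op.Mul(x, y), z)


def fused_replacement(op, x, y, z, **_):
    return op.FusedMulAdd(x, y, z, _domain="custom.domain")


rule = pattern.RewriteRule(mul_add_pattern, fused_replacement)
rule_set = pattern.RewriteRuleSet([rule], commute=True)
# `rule_set.apply_to_model(ir_model)` would then try every commuted variant of the
# pattern (four variants here: two for Add times two for Mul) against each node.
# ------------------------------------------------------------------------------------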
- deleted_nodes = matched_nodes[:-1] - for n in deleted_nodes: + for n in matched_nodes: for v in n.outputs: + if v in output_values: + continue if v.is_graph_output(): # value is an output-value of the graph/function. return False @@ -780,58 +816,7 @@ def _valid_to_replace(matched_nodes: Sequence[ir.Node]) -> bool: return True -# A type representing the domains/versions used in creating a replacement subgraph -UsedOpsets = List[Tuple[str, Optional[int]]] - - -class RewriterContext: - """Context parameter used to build the replacement pattern.""" - - # TODO(justinchuby): Merge with the rest of pattern building methods - def __init__(self): - self._tape = _tape.Tape() - self._used_opsets: UsedOpsets = [] - - def __getattr__(self, op_type: str) -> Any: - return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) - - def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]): - # TODO(rama): some of the following logic should move into the tape. - domain = kwargs.pop("domain", "") - version = kwargs.pop("version", None) - outputs = kwargs.pop("outputs", 1) - if isinstance(outputs, Sequence): - num_outputs = len(outputs) - else: - assert isinstance(outputs, int) - num_outputs = outputs - - self._used_opsets.append((domain, version)) - if num_outputs == 1: - value = self._tape.op(op_type, inputs=inputs, attributes=kwargs, domain=domain) - if isinstance(outputs, Sequence): - value.name = outputs[0] - return value - values = self._tape.op_multi_output( - op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs - ) - if isinstance(outputs, Sequence): - for value, name in zip(values, outputs): - value.name = name - return values - - @property - def nodes(self) -> Sequence[ir.Node]: - # TODO(rama): The current tape-based implementation will not track nodes added - # via overloaded operators, eg., `x + y`. One possible way to fix this is to - # have values/nodes know which tape they belong to (instead of a graph/function). - # However, it is unclear we need this feature for rewriting: we could also - # identify the nodes to be inserted from the replacement values (by tracing back). - return self._tape.nodes - - @property - def used_opsets(self) -> UsedOpsets: - return self._used_opsets +RewriterContext = _tape.Builder @dataclasses.dataclass @@ -841,7 +826,7 @@ class ReplacementSubgraph: match: MatchResult new_outputs: Sequence[ir.Value] new_nodes: Sequence[ir.Node] - used_opsets: UsedOpsets + used_opsets: _tape.UsedOpsets def always_true(*args, **kwargs) -> bool: @@ -899,7 +884,7 @@ def match( node: ir.Node, verbose: int = 0, ) -> MatchResult: - pass + """Match the pattern against the subgraph ending at the given node.""" def __str__(self) -> str: return str(self.pattern) @@ -907,9 +892,6 @@ def __str__(self) -> str: class SimplePatternMatcher(PatternMatcher): def __init__(self, pattern: GraphPattern) -> None: - assert ( - pattern.has_single_output_node - ), "SimplePatternMatcher only supports patterns with a single output node." super().__init__(pattern) def fail(self, reason: str) -> bool: @@ -1029,37 +1011,152 @@ def _match_node_output(self, pattern_value: NodeOutputPattern, value: ir.Value) ) return self._match_node(pattern_value.producer(), node) - def match( + def _init_match(self, verbose: int) -> None: + """Initialize the match state. 
Invoked before starting a new match.""" + self._verbose = verbose + self._matched: dict[NodePattern, ir.Node] = {} + self._match: MatchResult = MatchResult() + + def _get_output_values(self) -> list[ir.Value] | None: + """Get values bound to the output variables of the pattern.""" + output_values: list[ir.Value] = [] + unbound_values: list[str] = [] + for j, value_pattern in enumerate(self.pattern.outputs): + if value_pattern.name is not None: + if value_pattern.name in self._match.bindings: + output_values.append(self._match.bindings[value_pattern.name]) + else: + unbound_values.append(value_pattern.name) + elif isinstance(value_pattern, NodeOutputPattern): + i = value_pattern.output_index + node = value_pattern.producer() + if node in self._matched: + output_values.append(self._matched[node].outputs[i]) + else: + unbound_values.append(f"output_{j}") + elif isinstance(value_pattern, Constant): + raise NotImplementedError("Constant values as return-values not supported.") + if unbound_values: + self._match.fail(f"Error: Output values not found: {unbound_values}") + return None + return output_values + + def _match_single_output_node( self, model: ir.Model, graph_or_function: ir.Graph | ir.Function, node: ir.Node, - verbose: int = 0, ) -> MatchResult: del model del graph_or_function - self._verbose = verbose - self._matched: dict[NodePattern, ir.Node] = {} - self._match: MatchResult = MatchResult() pattern = self.pattern match = self._match - if len(node.outputs) != pattern.num_outputs: + + if not pattern.has_single_output_node: return match.fail( - f"Number of node outputs mismatch: expected {pattern.num_outputs}, got {len(node.outputs)}." + "Internal Error: SimplePatternMatcher should not be used for patterns with multiple output nodes." ) - if pattern._output_node is None: + + if not self._match_node(pattern.output_node, node): + return match + + output_values = self._get_output_values() + if output_values is None: + return match + if not _valid_to_replace(match.nodes, output_values): + return match.fail("Matched nodes have other uses preventing replacement.") + + if len(node.outputs) != pattern.num_outputs: return match.fail( - "Internal Error: SimplePatternMatcher should not be used for patterns with multiple output nodes." + f"Number of node outputs mismatch: expected {pattern.num_outputs}, got {len(node.outputs)}." ) - if self._match_node(pattern._output_node, node): - if not _valid_to_replace(match.nodes): - return match.fail("Matched nodes have other uses preventing replacement.") + match.outputs.extend(output_values) + return match + + def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult: + """Find a match for a pattern with multiple output nodes. + + For a pattern with K output nodes, the input candidate should specify K nodes + in the graph that will be matched against the pattern output nodes. + + Args: + candidate: An iterable of nodes that will be matched against the pattern output nodes. 
+ """ + match = self._match + for pattern_node, node in zip(self.pattern.output_nodes, candidate): + if not self._match_node(pattern_node, node): + return match + output_values = self._get_output_values() + if output_values is None: + return match + + if not _valid_to_replace(match.nodes, output_values): + return match.fail("Matched nodes have other uses preventing replacement.") - match.outputs.extend(node.outputs) + match.outputs.extend(output_values) return match + def match( + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + node: ir.Node, + verbose: int = 0, + ) -> MatchResult: + """Match the pattern against the subgraph ending at the given node. + + For patterns with multiple output nodes, the given node is matched + against the first output node in the pattern. For the remaining + output nodes in the pattern, we use a brute-force algorithm that + enumerates all possible combinations of nodes from the graph (with + a filter based on op-type). + + TODO: Consider omitting parameters model and graph_or_function. With + the new IR, the graph can be obtained from the node, and the model is + not used. But this is a shared abstract method of the Matcher interface, + so other matcher implementation also needs to be updated. More importantly, + matching in the presence of subgraphs (control-flow) can introduce some + complications which require careful consideration. + """ + + if self.pattern.has_single_output_node: + self._init_match(verbose) + return self._match_single_output_node(model, graph_or_function, node) + else: + # Note: This is a potentially expensive algorithm for matching patterns with + # multiple output nodes. For patterns with N output nodes, we try all possible + # combinations of N nodes from the graph, and check if they match the pattern. + # The first node is fixed to the node argument in this method call. We do + # some simple filtering by restricting the candidates for each remaining + # output nodes to graph nodes with the same op_type as the corresponding pattern + # node. For now, this is intended to be a simple, but robust, implementation + # that can be used for debugging and testing. The GenericPatternMatcher is a + # more sophisticated implementation, but incomplete. + pattern_output_nodes = self.pattern.output_nodes + op_to_nodes: dict[tuple[str, str, str], list[ir.Node]] = {} + for n in graph_or_function: + op_to_nodes.setdefault(n.op_identifier(), []).append(n) + all_nodes = iter(graph_or_function) + + def get_nodes(pattern_node): + id = pattern_node.op_identifier() + if id is None: + return all_nodes + return op_to_nodes.get(id, []) + + candidates = [iter([node])] + [get_nodes(pn) for pn in pattern_output_nodes[1:]] + match = None + for combination in itertools.product(*candidates): + self._init_match(verbose) + match = self._multi_match(combination) + if match: + return match + if match is None: + return MatchResult().fail("No match found.") + return match + class RewriteRule: def __init__( @@ -1238,58 +1335,6 @@ def rewrite(cls, op, x: ir.Value, perm: ir.Attr | None = None): ) -def _apply_delta( - graph_or_function: ir.Graph | ir.Function, - node: ir.Node, - delta: ReplacementSubgraph, -): - """Applies delta. - - This code is valid is the considered pattern has only one output. - In case of multi output replacements, there is not need to rename - the outputs. - - In case of multi-output design, the nodes may not be necessary inserted - all at the same position. 
To be convinced, you can take a pattern - producing two outputs, but the second one needs the first one and - another input appeared after the first outputs. What could be - the right place to inserted all of the node. - - The current implementation insert all the nodes at the same position - but checks there is not inconsistency. In that case, it fails. - We could reorder (long) or do more clever changes. - The reordering would probably happen not very often. - """ - - assert isinstance(delta, ReplacementSubgraph) - # Replace matched nodes with new nodes, matched values with new values - old_values = delta.match.outputs - new_values = delta.new_outputs - - for old_value, new_value in zip(old_values, new_values): - # Propagate relevant info from old value to new value - # TODO(Rama): Perhaps we should merge old and new types. As of now, new - # values don't have type information. Note that this could be a problem - # for semantics-altering rewrite-rules: we should allow users to override - # this for such rules. - new_value.type = old_value.type - new_value.shape = old_value.shape - new_value.const_value = old_value.const_value - new_value.name = old_value.name - - # Reconnect the users of the deleted node to use the new outputs - _convenience.replace_all_uses_with(old_values, new_values) - # Update graph/function outputs if the node generates output - replacement_mapping = dict(zip(old_values, new_values)) - for idx, graph_or_function_output in enumerate(graph_or_function.outputs): - if graph_or_function_output in replacement_mapping: - graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] - - # insert new nodes after the index node - graph_or_function.insert_after(node, delta.new_nodes) - graph_or_function.remove(delta.match.nodes, safe=True) - - class RewriteRuleSet: def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: if commute: @@ -1311,7 +1356,19 @@ def _apply_to_graph_or_function( delta = rule.try_rewrite(model, graph_or_function, node, verbose=verbose) if delta is None: continue - _apply_delta(graph_or_function, node, delta) + assert isinstance(delta, ReplacementSubgraph) + # TODO: This does not yet handle the problem of determining the correct insertion point + # for inserted nodes in the case of patterns with multiple output-nodes. The following + # is sufficient for patterns with a single output-node "node", which can serve as the + # insertion-point. 
+ _convenience.replace_nodes_and_values( + graph_or_function, + node, + delta.match.nodes, + delta.new_nodes, + delta.match.outputs, + delta.new_outputs, + ) count += 1 return count diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 0b2748b1d..5385a5233 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -109,7 +109,7 @@ def fast_gelu_pattern1(op, x): return (1.0 + tanh) * (0.5 * x) def fast_gelu(op, x): - return op.FastGelu(x, domain="com.microsoft") + return op.FastGelu(x, _domain="com.microsoft") return pattern.RewriteRule(fast_gelu_pattern1, fast_gelu) @@ -130,7 +130,7 @@ def fast_gelu_pattern1_long(op, x): return op.Mul(one_plus_tanh, half_x) def fast_gelu(op, x): - return op.FastGelu(x, domain="com.microsoft") + return op.FastGelu(x, _domain="com.microsoft") return pattern.RewriteRule(fast_gelu_pattern1_long, fast_gelu) @@ -315,7 +315,7 @@ def add_same(op, x): return x + x def double(op, x): - return op.Double(x, domain="custom.domain", version=10) + return op.Double(x, _domain="custom.domain", _version=10) rule = pattern.RewriteRule(add_same, double) @@ -339,7 +339,7 @@ def add_same(op, x): return x + x def double(op, x): - return op.Double(x, domain="custom.domain", version=10) + return op.Double(x, _domain="custom.domain", _version=10) rule = pattern.RewriteRule(add_same, double) @@ -373,7 +373,7 @@ def test_optional_attribute(self): def concat_pattern(op, x, y): seq = op.SequenceConstruct(x, y) - result = op.ConcatFromSequence(seq, outputs=["result"]) + result = op.ConcatFromSequence(seq, _outputs=["result"]) return result def concat(op, x, y, result: ir.Value): @@ -421,5 +421,18 @@ def concat(op, x, y, result: ir.Value): self.assertNotIn("axis", model.graph[0].attributes) +class PatternBuilderTest(unittest.TestCase): + def test_pattern_builder_context(self): + builder = pattern.OpsetPatternBuilder("", True) + with pattern.pattern_builder(builder): + x = builder.Op1() + y = builder.Op2(x) + z = x + y + w = builder.Op3(z) + _ = z * w + ops = [x.op_type for x in builder.nodes()] + self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3", "Mul"]) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/testing/__init__.py b/onnxscript/testing/__init__.py index c731f6e95..f7bb74980 100644 --- a/onnxscript/testing/__init__.py +++ b/onnxscript/testing/__init__.py @@ -389,7 +389,7 @@ def assert_onnx_proto_equal( a: The first ONNX proto. b: The second ONNX proto. 
""" - assert type(a) == type(b), f"Type not equal: {type(a)} != {type(b)}" # pylint: disable=unidiomatic-typecheck + assert type(a) is type(b), f"Type not equal: {type(a)} != {type(b)}" a_fields = {field.name: value for field, value in a.ListFields()} b_fields = {field.name: value for field, value in b.ListFields()} diff --git a/onnxscript/tools/benchmark/benchmark_helpers.py b/onnxscript/tools/benchmark/benchmark_helpers.py index e796a8808..3a874fa46 100644 --- a/onnxscript/tools/benchmark/benchmark_helpers.py +++ b/onnxscript/tools/benchmark/benchmark_helpers.py @@ -287,7 +287,7 @@ def common_export( if exporter == "script": torch.onnx.export( model, - inputs, + inputs, # type: ignore[arg-type] filename, do_constant_folding=False, input_names=[f"input{i}" for i in range(len(inputs))], diff --git a/onnxscript/tools/transformers_models/__init__.py b/onnxscript/tools/transformers_models/__init__.py index fd7a5807a..43dc81e9b 100644 --- a/onnxscript/tools/transformers_models/__init__.py +++ b/onnxscript/tools/transformers_models/__init__.py @@ -41,6 +41,7 @@ def export_to_onnx( prog = torch.onnx.export(model, args, dynamo=True) # pylint: disable=no-value-for-parameter else: prog = torch.onnx.dynamo_export(model, *args) + assert prog is not None model_proto = prog.model_proto if optimize: model_proto = onnxscript.optimizer.optimize( diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index ebca264fc..1acb0d4f4 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.5.1 +ruff==0.5.6 # MYPY mypy==1.10.1 types-PyYAML==6.0.12.11 diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 8cb245908..b4f0cc40c 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -39,7 +39,6 @@ import copy import dataclasses import functools -import sys from typing import Any, Callable, Collection, Optional import numpy as np @@ -284,6 +283,35 @@ def _grid_sample_input_wrangler( return args, kwargs +def _im2col_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # Move kernel_size, dilation, padding and stride from args to kwargs + if len(args) == 5: + # Handle stride + stride = args.pop() + if isinstance(stride, np.ndarray): # convert stride to list[int] + stride = stride.tolist() + kwargs["stride"] = stride + # Handle padding + padding = args.pop() + if isinstance(padding, np.ndarray): # convert padding to list[int] + padding = padding.tolist() + kwargs["padding"] = padding + # Handle dilation + dilation = args.pop() + if isinstance(dilation, np.ndarray): # convert dilation to list[int] + dilation = dilation.tolist() + kwargs["dilation"] = dilation + # Handle kernel_size + kernel_size = args.pop() + if isinstance(kernel_size, np.ndarray): # convert kernel_size to list[int] + kernel_size = kernel_size.tolist() + kwargs["kernel_size"] = kernel_size + + return args, kwargs + + def _linalg_vector_norm_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -713,19 +741,25 @@ def _where_input_wrangler( dtypes=(torch.bool,), reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", ), - TorchLibOpInfo("clamp_max", core_ops.aten_clamp).skip( - 
enabled_if=sys.version_info[:2] >= (3, 9) or sys.platform != "win32", - reason="fails in this particular case", - ), - TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max).skip( + TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max) + .skip( matcher=lambda sample: len(sample.input.shape) == 0, enabled_if=version_utils.onnxruntime_older_than("1.16"), reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", + ) + .skip( + reason="Size 0 inputs are not handled by design", + matcher=lambda sample: sample.input.numel() == 0, ), - TorchLibOpInfo("clamp_min", core_ops.aten_clamp_min).skip( + TorchLibOpInfo("clamp_min", core_ops.aten_clamp_min) + .skip( matcher=lambda sample: len(sample.input.shape) == 0, enabled_if=version_utils.onnxruntime_older_than("1.16"), reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", + ) + .skip( + reason="Size 0 inputs are not handled by design", + matcher=lambda sample: sample.input.numel() == 0, ), TorchLibOpInfo("clone", core_ops.aten_clone), TorchLibOpInfo("complex", core_ops.aten_complex), @@ -1054,39 +1088,6 @@ def _where_input_wrangler( "new_zeros", core_ops.aten_new_zeros, ), - TorchLibOpInfo( - "nn.functional.adaptive_avg_pool1d", - nn_ops.aten_adaptive_avg_pool1d, - ) - .xfail( - # Shape should be [N, C, D1] - matcher=lambda sample: sample.args[0] not in {1, (1,)}, - reason="only global pooling is supported; only batched inputs are supported", - ) - .xfail( - reason="ORT fails on a cast node it inserts for float16. https://github.com/microsoft/onnxruntime/issues/16449", - dtypes=(torch.float16,), - test_class_name="TestOutputConsistencyEager", - ), - TorchLibOpInfo( - "nn.functional.adaptive_avg_pool2d", - nn_ops.aten_adaptive_avg_pool2d, - ).xfail( - matcher=lambda sample: sample.args[0] != (1, 1), - reason="only global pooling is supported; only batched inputs are supported", - ), - TorchLibOpInfo( - "nn.functional.adaptive_avg_pool3d", - nn_ops.aten_adaptive_avg_pool3d, - ) - .xfail( - matcher=lambda sample: sample.args[0] != (1, 1, 1), - reason="only global pooling is supported; only batched inputs are supported", - ) - .xfail( - dtypes=(torch.float16,), - reason="fixme: RuntimeError: ORT inference error GlobalAveragePool. 
https://github.com/microsoft/onnxruntime/issues/16449", - ), TorchLibOpInfo("nn.functional.celu", nn_ops.aten_celu), TorchLibOpInfo("nn.functional.celu_type_promoted", nn_ops.aten_celu_type_promoted), TorchLibOpInfo( @@ -1142,22 +1143,11 @@ def _where_input_wrangler( tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (8e-2, 4e-4)}, ), TorchLibOpInfo("nn.functional.mish", nn_ops.aten_mish), - TorchLibOpInfo( - "nn.functional.nll_loss_weight", - nn_ops.aten_nll_loss_weight, - tolerance={torch.float16: (5e-2, 1e-2)}, - input_wrangler=_nll_loss_input_wrangler, - ).skip( - matcher=lambda sample: "weight" not in sample.kwargs, - reason="this Aten overload need weight as kwargs", - ), TorchLibOpInfo( "nn.functional.nll_loss", nn_ops.aten_nll_loss, input_wrangler=_nll_loss_input_wrangler, - ).skip( - matcher=lambda sample: "weight" in sample.kwargs, - reason="this Aten overload doesn't accept weight as kwargs", + tolerance={torch.float16: (5e-2, 1e-2)}, ), TorchLibOpInfo( "nn.functional.pixel_shuffle", @@ -1501,6 +1491,33 @@ def _where_input_wrangler( ), TorchLibOpInfo("stack", core_ops.aten_stack), TorchLibOpInfo("stack", core_ops.aten_stack_complex, complex=True), + TorchLibOpInfo( + "std_mean", + core_ops.aten_std_mean, + ).xfail( + # kwargs is empty + matcher=lambda sample: len(sample.kwargs) > 0, + reason="this Aten overload only support input[0]=tensor and input[1]=bool as input without any kwargs", + ), + TorchLibOpInfo( + "std_mean_dim", + core_ops.aten_std_mean_dim, + ).xfail( + # kwargs["dim"] must exist, kwargs["correction"] must not exist + matcher=lambda sample: not ( + sample.kwargs.get("dim", None) is not None + and sample.kwargs.get("correction", None) is None + ), + reason="this Aten overload only support with 'dim' argument and without 'correction' argument", + ), + TorchLibOpInfo( + "std_mean_correction", + core_ops.aten_std_mean_correction, + ).skip( + # Don't accept input[1]=bool and 'correction' must be in kwargs + matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs, + reason="this Aten overload only support when correction attribute exists", + ), TorchLibOpInfo("sub", core_ops.aten_sub), TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True), # TorchLibOpInfo("sym_size", core_ops.aten_sym_size), # no test case in OPS_DB @@ -1584,16 +1601,28 @@ def _where_input_wrangler( TorchLibOpInfo( "arange_start_step", core_ops.aten_arange_start_step, - ).xfail( + ) + .skip( matcher=lambda sample: len(sample.args) != 2, reason="arange_start_step overload takes three arguments (input, start, step)", + ) + .skip( + matcher=lambda sample: sample.kwargs.get("dtype") is None, + reason="dtype needs to be specified for non-float tensors", + dtypes=(torch.float16, torch.int64, torch.int32), ), TorchLibOpInfo( "arange_start", core_ops.aten_arange_start, - ).skip( + ) + .skip( matcher=lambda sample: len(sample.args) != 1, reason="arange_start overload takes two arguments (input, start)", + ) + .skip( + matcher=lambda sample: sample.kwargs.get("dtype") is None, + reason="dtype needs to be specified for non-float tensors", + dtypes=(torch.float16, torch.int64, torch.int32), ), TorchLibOpInfo( "arange", @@ -1603,13 +1632,18 @@ def _where_input_wrangler( dtypes=(torch.int32,), reason="fixme: output shape mismatch in edge cases. 
https://github.com/microsoft/onnxscript/issues/974", ) - .xfail( + .skip( matcher=lambda sample: len(sample.args) != 0, reason="arange overload takes single argument", ) .xfail( matcher=lambda sample: sample.kwargs.get("end") is not None, reason="arange overload does not support positional 'end' argument", + ) + .skip( + matcher=lambda sample: sample.kwargs.get("dtype") is None, + reason="dtype needs to be specified for non-float tensors", + dtypes=(torch.float16, torch.int64, torch.int32), ), TorchLibOpInfo("argmax", core_ops.aten_argmax) .skip( @@ -1669,6 +1703,14 @@ def _where_input_wrangler( matcher=lambda sample: sample.args[1] == 2, reason="fixme: 'bicubic' mode in ORT implemented differently with Torch", ), + TorchLibOpInfo( + "nn.functional.group_norm", + nn_ops.aten_group_norm, + tolerance={torch.float16: (1e-2, 7e-3)}, + ).xfail( + matcher=lambda sample: any(dim == 0 for dim in sample.input.shape), + reason="Using op.InstanceNormalization to simulate GroupNorm, which does not support 0-dim input", + ), TorchLibOpInfo("heaviside", core_ops.aten_heaviside), TorchLibOpInfo( "hstack", @@ -1888,6 +1930,15 @@ def _where_input_wrangler( tolerance={torch.float16: (8e-2, 1e-4)}, ), TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu), + TorchLibOpInfo( + "nn.functional.unfold", + nn_ops.aten_im2col, + input_wrangler=_im2col_input_wrangler, + ).xfail( + matcher=lambda sample: any(dim == 0 for dim in sample.input.shape) + or not sample.input.shape, + reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", + ), TorchLibOpInfo("nn.functional.linear", nn_ops.aten_linear).skip( # input: input, args: weight, bias; so len(args) == 2 means bias is provided matcher=lambda sample: len(sample.args) != 1, @@ -2146,6 +2197,33 @@ def _where_input_wrangler( dtypes=(torch.float16,), reason="RuntimeError: MKL FFT doesn't support tensors of type: Half", ), + TorchLibOpInfo( + "std", + core_ops.aten_std, + ).xfail( + # kwargs must be empty + matcher=lambda sample: len(sample.kwargs) > 0, + reason="this Aten overload only support input[0]=tensor and input[1]=bool as input without any kwargs", + ), + TorchLibOpInfo( + "std_dim", + core_ops.aten_std_dim, + ).xfail( + # kwargs["dim"] must exist, kwargs["correction"] must not exist + matcher=lambda sample: not ( + sample.kwargs.get("dim", None) is not None + and sample.kwargs.get("correction", None) is None + ), + reason="this Aten overload only support with 'dim' argument and without 'correction' argument", + ), + TorchLibOpInfo( + "std_correction", + core_ops.aten_std_correction, + ).skip( + # Don't accept input[1]=bool and 'correction' must be in kwargs + matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs, + reason="this Aten overload only support when correction attribute exists", + ), TorchLibOpInfo( "sum", core_ops.aten_sum_dim_IntList, @@ -2267,9 +2345,6 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo( OPS_DB, "nn.functional.linear", ("nn.functional.linear_bias",) ) -ops_test_common.duplicate_opinfo( - OPS_DB, "nn.functional.nll_loss", ("nn.functional.nll_loss_weight",) -) ops_test_common.duplicate_opinfo( OPS_DB, "nn.functional.pad", @@ -2295,6 +2370,8 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",)) ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",)) ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",)) +ops_test_common.duplicate_opinfo(OPS_DB, "std_mean", ("std_mean_dim", 
"std_mean_correction")) +ops_test_common.duplicate_opinfo(OPS_DB, "std", ("std_dim", "std_correction")) ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction")) ops_test_common.duplicate_opinfo(OPS_DB, "var", ("var_dim", "var_correction")) ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",)) diff --git a/tools/optimize.py b/tools/optimize.py new file mode 100644 index 000000000..276cda890 --- /dev/null +++ b/tools/optimize.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Utility for optimizing ONNX models. + +Usage: + python optimize.py model.onnx optimized_model.onnx +""" + +import argparse +import os + +import onnx +import onnx.inliner + +import onnxscript + + +def main(args) -> None: + path = args.path + output_path = args.output_path + + model = onnx.load(path, load_external_data=False) + # Hack: Change the working directory to the model directory so the optimizer + # can load external data files with relative paths. + # TODO: Remove this hack by fixing the optimizer to handle external data files properly. + pwd = os.getcwd() + model_dir = os.path.dirname(path) + os.chdir(model_dir) + model = onnxscript.optimizer.optimize(model) + model = onnx.inliner.inline_local_functions(model) + # Optimize again in case inlining created new opportunities. + model = onnxscript.optimizer.optimize(model) + + os.chdir(pwd) + onnx.save(model, output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Optimize an ONNX model.") + parser.add_argument("path", type=str, help="Path to the ONNX model.") + parser.add_argument("output_path", type=str, help="Path to save the optimized model.") + main(parser.parse_args())