Skip to content

Use ir methods to replace onnx helper #2091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
33 changes: 7 additions & 26 deletions onnxscript/_internal/autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@

import numpy as np
import onnx
from onnx import helper, numpy_helper
from onnx.defs import OpSchema

from onnxscript import tensor
from onnxscript import ir, tensor

if TYPE_CHECKING:
from onnxscript import converter
Expand All @@ -37,29 +36,7 @@


def pyvalue_to_onnx_tensor(tensor_name: str, pyvalue):
if isinstance(pyvalue, np.ndarray):
return numpy_helper.from_array(pyvalue, tensor_name)
if isinstance(pyvalue, list):
if len(pyvalue) == 0:
raise ValueError("Cannot convert an empty list to tensor")
pytype = type(pyvalue[0])
if not all(isinstance(e, pytype) for e in pyvalue):
raise ValueError(
"Cannot convert an list with elements of different types to tensor"
)
return helper.make_tensor(
tensor_name,
_py_type_to_onnx_type(pytype),
[len(pyvalue)],
pyvalue,
)
onnx_type = _py_type_to_onnx_type(type(pyvalue))
if onnx_type is onnx.TensorProto.BOOL:
return helper.make_tensor(tensor_name, onnx_type, [], [int(pyvalue)])
if onnx_type is onnx.TensorProto.STRING:
return helper.make_tensor(tensor_name, onnx_type, [], vals=[pyvalue.encode("utf-8")])

return helper.make_tensor(tensor_name, onnx_type, [], [pyvalue])
return ir.serde.serialize_tensor(ir.tensor(pyvalue, name=tensor_name))


_REPEATED_ATTRIBUTE_TYPES = frozenset(
Expand Down Expand Up @@ -103,7 +80,11 @@
name=key, type=attr_type, t=pyvalue_to_onnx_tensor(name_generator(), value)
)
else:
return onnx.helper.make_attribute(key, value)
attr = ir.convenience.convert_attribute(
key, value, attr_type=ir.AttributeType(attr_type)
)
assert isinstance(attr, ir.Attr)
return ir.serde.serialize_attribute(attr)


# Utilities to convert python values into onnxscript tensors.
Expand Down
21 changes: 10 additions & 11 deletions onnxscript/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import numpy as np
import onnx
import onnx.helper

from onnxscript import tensor

Expand Down Expand Up @@ -65,26 +64,26 @@
def value_to_type_proto(val):
"""Return the ONNX type of a python-value."""
if isinstance(val, (np.ndarray, tensor.Tensor)):
elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype)
elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype) # noqa: TID251
shape = val.shape
return onnx.helper.make_tensor_type_proto(elem_type, shape)
return onnx.helper.make_tensor_type_proto(elem_type, shape) # noqa: TID251
if isinstance(val, int):
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, [])
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, []) # noqa: TID251

Check warning on line 71 in onnxscript/_internal/utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_internal/utils.py#L71

Added line #L71 was not covered by tests
if isinstance(val, (float, np.float32)):
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [])
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, []) # noqa: TID251

Check warning on line 73 in onnxscript/_internal/utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_internal/utils.py#L73

Added line #L73 was not covered by tests
if isinstance(val, list):
if len(val) > 0:
return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0]))
return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0])) # noqa: TID251

Check warning on line 76 in onnxscript/_internal/utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_internal/utils.py#L76

Added line #L76 was not covered by tests
# Edge-case. Cannot determine a suitable ONNX type for an empty list.
# Should be using a typed-value instead.
# Treated as a sequence of tensors of float-type.
return onnx.helper.make_sequence_type_proto(
onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None)
return onnx.helper.make_sequence_type_proto( # noqa: TID251

Check warning on line 80 in onnxscript/_internal/utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_internal/utils.py#L80

Added line #L80 was not covered by tests
onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None) # noqa: TID251
)
if isinstance(val, numbers.Number):
nparray = np.array(val)
elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype)
return onnx.helper.make_tensor_type_proto(elem_type, [])
elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype) # noqa: TID251
return onnx.helper.make_tensor_type_proto(elem_type, []) # noqa: TID251

Check warning on line 86 in onnxscript/_internal/utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_internal/utils.py#L85-L86

Added lines #L85 - L86 were not covered by tests
raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")


Expand All @@ -93,7 +92,7 @@
skipping any None values.
"""
return [
onnx.helper.make_value_info(name, value_to_type_proto(val))
onnx.helper.make_value_info(name, value_to_type_proto(val)) # noqa: TID251
for (name, val) in name_values
if val is not None
]
2 changes: 1 addition & 1 deletion onnxscript/_legacy_ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@
if isinstance(self.value, np.ndarray):
return self.value
if isinstance(self.value, onnx.TensorProto):
return onnx.numpy_helper.to_array(self.value)
return onnx.numpy_helper.to_array(self.value) # noqa: TID251

Check warning on line 145 in onnxscript/_legacy_ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/_legacy_ir/__init__.py#L145

Added line #L145 was not covered by tests
return None

def def_node(self) -> Node | None:
Expand Down
1 change: 1 addition & 0 deletions onnxscript/_legacy_ir/visitor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: TID251
from __future__ import annotations

import dataclasses
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/backend/onnx_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# ruff: noqa: TID251

import os
import textwrap
Expand Down
9 changes: 4 additions & 5 deletions onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy
import onnx
from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, ValueInfoProto
from onnx.helper import make_node

import onnxscript.onnx_types
import onnxscript.type_annotation
Expand Down Expand Up @@ -68,10 +67,10 @@
if tensor_proto.data_type in {TensorProto.FLOAT, TensorProto.INT64}:
rank = len(tensor_proto.dims)
if rank == 0:
array = onnx.numpy_helper.to_array(tensor_proto).reshape(1)
array = onnx.numpy_helper.to_array(tensor_proto).reshape(1) # noqa: TID251

Check warning on line 70 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L70

Added line #L70 was not covered by tests
return repr(array[0])
if rank == 1 and tensor_proto.dims[0] < 5:
return repr(list(onnx.numpy_helper.to_array(tensor_proto)))
return repr(list(onnx.numpy_helper.to_array(tensor_proto))) # noqa: TID251

Check warning on line 73 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L73

Added line #L73 was not covered by tests
return None


Expand Down Expand Up @@ -161,7 +160,7 @@
if onnx.external_data_helper.uses_external_data(tensor_proto):
return tensor_proto
else:
return onnx.numpy_helper.to_array(tensor_proto)
return onnx.numpy_helper.to_array(tensor_proto) # noqa: TID251

Check warning on line 163 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L163

Added line #L163 was not covered by tests
# TODO:
# - onnx.AttributeProto.GRAPH
# - onnx.AttributeProto.SPARSE_TENSOR
Expand Down Expand Up @@ -348,7 +347,7 @@
)
self.skipped_initializers[init_py_name] = init
continue
node = make_node(
node = onnx.helper.make_node( # noqa: TID251

Check warning on line 350 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L350

Added line #L350 was not covered by tests
"Constant",
[],
[self._translate_onnx_var(init.name)], # type: ignore[list-item]
Expand Down
13 changes: 7 additions & 6 deletions onnxscript/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np
import onnx
import onnx.defs
import onnx.helper
import onnx.helper # noqa: TID251
import onnx.reference
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -430,21 +430,22 @@ def make_tensor_name() -> str:
num_outputs = compute_num_outputs(schema, args, kwargs)
outputs = [f"output{i}" for i in range(num_outputs)]

node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain)
node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain) # noqa: TID251
node.attribute.extend(
make_attr(key, value) for key, value in kwargs.items() if value is not None
)
input_value_infos = utils.values_to_value_infos(zip(inputs, args))
implicit_value_infos = utils.values_to_value_infos(implicit_args.items())
output_value_infos = [
onnx.helper.make_value_info(name, onnx.TypeProto()) for name in outputs
onnx.helper.make_value_info(name, onnx.TypeProto()) # noqa: TID251
for name in outputs
]

graph = onnx.helper.make_graph(
graph = onnx.helper.make_graph( # noqa: TID251
[node], "node_graph", input_value_infos + implicit_value_infos, output_value_infos
)
opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version)
model = onnx.helper.make_model(
opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version) # noqa: TID251
model = onnx.helper.make_model( # noqa: TID251
graph,
opset_imports=[opset_id],
ir_version=irbuilder.select_ir_version(schema.since_version, domain=schema.domain),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: TID251
"""Graph building functions for torchscript graph backend."""

from __future__ import annotations
Expand Down
19 changes: 5 additions & 14 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import math
from typing import Optional, Sequence, Tuple, TypeVar, Union

import onnx

from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64, ir
from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
Expand Down Expand Up @@ -1800,15 +1798,11 @@ def _aten__scaled_dot_product_flash_attention_fillin_empty_outputs(
op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3])
)
logsumexp = op.Expand(0.0, query_first_three_dims)
# TODO: Eliminate `make_tensor` usage when ORT supports empty tensor.
empty_tensor_int = op.Cast(
op.ConstantOfShape(
op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
),
to=INT64.dtype,
empty_tensor_int = op.ConstantOfShape(
op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64))
)
empty_tensor_float = op.ConstantOfShape(
op.Constant(value=onnx.helper.make_tensor("Empty_FLOATS", INT64.dtype, [0], []))
op.Constant(value=ir.tensor([], dtype=ir.DataType.FLOAT))
)
empty_int = op.Constant(value_int=0)

Expand Down Expand Up @@ -1883,11 +1877,8 @@ def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
logsum_exp = op.Expand(0.0, op.Concat(query_first_dims, num_heads, [0], axis=0))

# See Note [Seed and Offset]:
empty_tensor_int = op.Cast(
op.ConstantOfShape(
op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
),
to=INT64.dtype,
empty_tensor_int = op.ConstantOfShape(
op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64))
)

return logsum_exp, empty_tensor_int
Expand Down
1 change: 1 addition & 0 deletions onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: TID251
from __future__ import annotations

import dataclasses
Expand Down
22 changes: 13 additions & 9 deletions onnxscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
import sys
from typing import Any, Callable, Optional, Sequence

import onnx.helper

import onnxscript
from onnxscript import converter, irbuilder, values
from onnxscript import converter, ir, irbuilder, values
from onnxscript._internal import ast_utils


Expand Down Expand Up @@ -157,11 +155,17 @@
# Since we don't yet have LibProto defined, we use a ModelProto as a temporary
# container for the list of functions exported as a library, with an empty graph
# and dummy opset_imports.
model = onnx.helper.make_model(
onnx.GraphProto(),
functions=[f.to_function_proto() for f in functions],

# TODO(justinchuby): This function is not well supported. We should consider removing it
model = ir.Model(

Check warning on line 160 in onnxscript/main.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/main.py#L160

Added line #L160 was not covered by tests
ir.Graph(
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": 15},
),
functions=[ir.serde.deserialize_function(f.to_function_proto()) for f in functions],
ir_version=10,
producer_name="p2o",
opset_imports=[onnx.helper.make_opsetid("", 15)],
)

onnx.save(model, filename)
ir.save(model, filename)

Check warning on line 171 in onnxscript/main.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/main.py#L171

Added line #L171 was not covered by tests
3 changes: 1 addition & 2 deletions onnxscript/onnx_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import ClassVar, Optional, Tuple, Union

import onnx
import onnx.helper

import onnxscript.ir

Expand Down Expand Up @@ -99,7 +98,7 @@ def to_type_proto(cls) -> onnx.TypeProto:
shape = cls.shape # example: "FLOAT[10,20]"
else:
shape = [cls.shape] # example: "FLOAT[10]"
return onnx.helper.make_tensor_type_proto(cls.dtype, shape)
return onnx.helper.make_tensor_type_proto(cls.dtype, shape) # noqa: TID251

@classmethod
def to_string(cls) -> str:
Expand Down
7 changes: 4 additions & 3 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,9 +829,10 @@ def _do_inference(self, node: ir.Node) -> None:

# TODO: handle optional inputs
def get_constant_value(x: ir.Value) -> onnx.TensorProto | None:
value = _get_numpy_value(x)
if isinstance(value, np.ndarray) and value.size < 20:
return onnx.numpy_helper.from_array(value, x.name)
value = _get_numpy_value(x, size_limit=20)
if value is not None:
assert x.const_value is not None
return ir.serde.serialize_tensor(x.const_value)
return None

def get_type(value: ir.Value) -> onnx.TypeProto | None:
Expand Down
1 change: 1 addition & 0 deletions onnxscript/optimizer/_legacy/constant_folding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: TID251
from __future__ import annotations

import logging
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/optimizer/_legacy/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------

# ruff: noqa: TID251
from __future__ import annotations

import dataclasses
Expand Down
6 changes: 2 additions & 4 deletions onnxscript/rewriter/cast_constant_of_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import logging

import onnx.helper

from onnxscript import ir
from onnxscript.rewriter import pattern

Expand All @@ -20,7 +18,7 @@ def cast_constant_of_shape(op, shape, scalar, dtype):
def fused_cast_constant_of_shape(op, shape: ir.Value, scalar: ir.Attr, dtype: ir.Attr, **_):
# Cast scalar (a TensorProto attribute) to the specified dtype
scalar_value = scalar.value.numpy().item()
cast_value = onnx.helper.make_tensor("value", dtype.value, (1,), [scalar_value])
cast_value = ir.tensor([scalar_value], dtype=ir.DataType(dtype.as_int()))
return op.ConstantOfShape(shape, value=cast_value)


Expand All @@ -30,7 +28,7 @@ def cast_constant_of_shape_without_value(op, shape, dtype):


def fused_cast_constant_of_shape_without_value(op, shape, dtype, **_):
zero = onnx.helper.make_tensor("value", dtype.value, (1,), [0])
zero = ir.tensor([0], dtype=ir.DataType(dtype.as_int()))
return op.ConstantOfShape(shape, value=zero)


Expand Down
Loading
Loading