Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.3.2
+++++

* :pr:`103`: fix import issue with the latest onnx version
* :pr:`101`: fix as_tensor in onnx_text_plot_tree

0.3.1
Expand Down
6 changes: 3 additions & 3 deletions _unittests/ut_npx/test_npx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,7 +1268,7 @@ def test_numpy_op_bin_reduce(self):
"xor", lambda x, y: (x.sum() == y.sum()) ^ (((-x).sum()) == y.sum())
)

def common_test_inline(self, fonx, fnp, tcst=0):
def common_test_inline(self, fonx, fnp, tcst=0, atol=1e-10):
f = fonx(Input("A"))
self.assertIsInstance(f, Var)
onx = f.to_onnx(constraints={0: Float64[None], (0, False): Float64[None]})
Expand All @@ -1277,7 +1277,7 @@ def common_test_inline(self, fonx, fnp, tcst=0):
y = fnp(x)
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"A": x})
self.assertEqualArray(y, got[0], atol=1e-10)
self.assertEqualArray(y, got[0], atol=atol)

def common_test_inline_bin(self, fonx, fnp, tcst=0):
f = fonx(Input("A"), Input("B"))
Expand Down Expand Up @@ -1470,7 +1470,7 @@ def test_equal(self):

@unittest.skipIf(scipy is None, reason="scipy is not installed.")
def test_erf(self):
self.common_test_inline(erf_inline, scipy.special.erf)
self.common_test_inline(erf_inline, scipy.special.erf, atol=1e-7)

def test_exp(self):
self.common_test_inline(exp_inline, np.exp)
Expand Down
4 changes: 2 additions & 2 deletions _unittests/ut_translate_api/test_translate_classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def _run(cls, code):
import onnx.helper
import onnx.numpy_helper
import onnx_array_api.translate_api.make_helper
import onnx.reference.custom_element_types
import ml_dtypes

def from_array_extended(tensor, name=None):
dt = tensor.dtype
Expand All @@ -433,7 +433,7 @@ def from_array_extended(tensor, name=None):
globs.update(onnx.helper.__dict__)
globs.update(onnx.numpy_helper.__dict__)
globs.update(onnx_array_api.translate_api.make_helper.__dict__)
globs.update(onnx.reference.custom_element_types.__dict__)
globs.update(ml_dtypes.__dict__)
globs["from_array_extended"] = from_array_extended
locs = {}
try:
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_xrun_doc/test_documentation_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def run_test(self, fold: str, name: str, verbose=0) -> int:
cmds = [sys.executable, "-u", os.path.join(fold, name)]
p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
res = p.communicate()
out, err = res
_out, err = res
st = err.decode("ascii", errors="ignore")
if st and "Traceback" in st:
if '"dot" not found in path.' in st:
Expand Down
7 changes: 6 additions & 1 deletion onnx_array_api/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,12 @@ def align_text(text, size):
return text[:h] + "..." + text[-h + 1 :]

dicts = self.as_dict(filter_node=filter_node, sort_key=sort_key)
max_nc = max(max(_["nc1"] for _ in dicts), max(_["nc2"] for _ in dicts))
set1 = [_["nc1"] for _ in dicts]
set2 = [_["nc1"] for _ in dicts]
if set1 or set2:
max_nc = max([*set1, *set2])
else:
max_nc = 1
dg = int(math.log(max_nc) / math.log(10) + 1.5)
line_format = (
"{indent}{fct} -- {nc1: %dd} {nc2: %dd} -- {tin:1.5f} {tall:1.5f}"
Expand Down
3 changes: 2 additions & 1 deletion onnx_array_api/translate_api/base_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
if value[0].type == AttributeProto.TENSOR:
repl = {"bool": "bool_", "object": "object_", "str": "str_"}
sdtype = repl.get(str(v.dtype), str(str(v.dtype)))
package = "np" if hasattr(np, sdtype) else "ml_dtypes"
return [], (
f"from_array(np.array({v.tolist()}, dtype=np.{sdtype}), "
f"from_array(np.array({v.tolist()}, dtype={package}.{sdtype}), "
f"name={value[0].name!r})"
)
if isinstance(v, (int, float, list)):
Expand Down
6 changes: 5 additions & 1 deletion onnx_array_api/translate_api/builder_emitter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, List
import numpy as np
from onnx import TensorProto
from onnx.numpy_helper import to_array
from .base_emitter import BaseEmitter
Expand Down Expand Up @@ -135,7 +136,10 @@ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
val = to_array(init)
stype = str(val.dtype).split(".")[-1]
name = self._clean_result_name(init.name)
rows.append(f" {name} = np.array({val.tolist()}, dtype=np.{stype})")
package = "np" if hasattr(np, stype) else "ml_dtypes"
rows.append(
f" {name} = np.array({val.tolist()}, dtype={package}.{stype})"
)
return rows

def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
Expand Down
5 changes: 3 additions & 2 deletions onnx_array_api/translate_api/inner_emitter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from onnx import AttributeProto
from ..annotations import ELEMENT_TYPE_NAME
from .base_emitter import BaseEmitter
Expand Down Expand Up @@ -105,7 +106,7 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
else:
raise NotImplementedError(f"Unexpected dtype={sdtype}.")
else:
sdtype = f"np.{sdtype}"
sdtype = f"np.{sdtype}" if hasattr(np, sdtype) else f"ml_dtypes.{sdtype}"

return [
"initializers.append(",
Expand Down Expand Up @@ -233,7 +234,7 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
else:
raise NotImplementedError(f"Unexpected dtype={sdtype}.")
else:
sdtype = f"np.{sdtype}"
sdtype = f"np.{sdtype}" if hasattr(np, sdtype) else f"ml_dtypes.{sdtype}"
if value.size <= 16:
return [
"initializers.append(",
Expand Down
4 changes: 3 additions & 1 deletion onnx_array_api/translate_api/light_emitter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, List
import numpy as np
from ..annotations import ELEMENT_TYPE_NAME
from .base_emitter import BaseEmitter

Expand Down Expand Up @@ -43,8 +44,9 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
value = kwargs["value"]
repl = {"bool": "bool_", "object": "object_", "str": "str_"}
sdtype = repl.get(str(value.dtype), str(str(value.dtype)))
package = "np" if hasattr(np, sdtype) else "ml_dtypes"
return [
f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))",
f"cst(np.array({value.tolist()}, dtype={package}.{sdtype}))",
f"rename({name!r})",
]

Expand Down
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ select = [
]

[tool.ruff.lint.per-file-ignores]
"**" = ["B905", "C401", "C408", "C413", "PYI041", "RUF012", "RUF100", "RUF010", "SIM108", "SIM910", "SIM110", "SIM102", "SIM114", "SIM103", "UP015", "UP027", "UP031", "UP034", "UP032", "UP006", "UP035", "UP007", "UP038"]
"**" = [
"B905", "C401", "C408", "C413", "PYI041",
"RUF012", "RUF100", "RUF010",
"SIM108", "SIM910", "SIM110", "SIM102", "SIM114", "SIM103",
"UP015", "UP027", "UP031", "UP034", "UP032", "UP006", "UP035", "UP007", "UP038", "UP045"
]
"**/plot*.py" = ["B018"]
"_doc/examples/plot_first_example.py" = ["E402", "F811"]
"_doc/examples/plot_onnxruntime.py" = ["E402", "F811"]
Expand Down