Skip to content

Commit 63f3809

Browse files
authored
Use the new IR for building graphs in torchlib (#1354)
Use the new IR for building graphs in torchlib. I created a flag `TORCHLIB_EXPERIMENTAL_USE_IR` to control the feature. When enabled the classes in `graph_building` will be swapped to use the new IR. The exporter should see the same interfaces and not feel anything. After the transition period we can decide where this should live. Set `TORCHLIB_EXPERIMENTAL_USE_IR=1` to enable this feature. Fix #997
1 parent f06f303 commit 63f3809

File tree

10 files changed

+898
-7
lines changed

10 files changed

+898
-7
lines changed

.github/workflows/main.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ jobs:
3131
- py311-onnx-weekly
3232
- py311-ort-nightly
3333
- py311-experimental-torchlib-tracing
34+
- py311-experimental-torchlib-onnx-ir
3435
- py310
3536
- py39
3637
- py38
@@ -62,6 +63,9 @@ jobs:
6263
- name: py311-experimental-torchlib-tracing
6364
python-version: "3.11"
6465
nox-tag: test-experimental-torchlib-tracing
66+
- name: py311-experimental-torchlib-onnx-ir
67+
python-version: "3.11"
68+
nox-tag: test-experimental-torchlib-onnx-ir
6569
runs-on: ${{ matrix.os }}
6670
steps:
6771
- uses: actions/checkout@v4

noxfile.py

+21
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,24 @@ def test_experimental_torchlib_tracing(session):
126126
*session.posargs,
127127
env={"TORCHLIB_EXPERIMENTAL_PREFER_TRACING": "1"},
128128
)
129+
130+
131+
@nox.session(tags=["test-experimental-torchlib-onnx-ir"])
132+
def test_experimental_torchlib_onnx_ir(session):
133+
"""Test TorchLib using the ONNX IR to build graphs."""
134+
session.install(
135+
*COMMON_TEST_DEPENDENCIES,
136+
PYTORCH,
137+
TORCHVISON,
138+
ONNX,
139+
*ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
140+
)
141+
session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
142+
session.install(".", "--no-deps")
143+
session.run("pip", "list")
144+
session.run(
145+
"pytest",
146+
"tests/function_libs/torch_lib/ops_test.py",
147+
*session.posargs,
148+
env={"TORCHLIB_EXPERIMENTAL_USE_IR": "1"},
149+
)

onnxscript/_internal/runtime_typing.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
installed.
66
"""
77

8+
import typing
89
import warnings
910

1011
__all__ = [
1112
"checked",
1213
]
1314

15+
T = typing.TypeVar("T", bound=typing.Callable[..., typing.Any])
16+
1417
try:
1518
from beartype import beartype as checked
1619
from beartype import roar as _roar
@@ -25,12 +28,12 @@
2528
)
2629
except ImportError:
2730

28-
def checked(func): # type: ignore[no-redef]
31+
def checked(func: T) -> T: # type: ignore[no-redef]
2932
return func
3033

3134
except Exception as e: # pylint: disable=broad-exception-caught
3235
# Warn errors that are not import errors (unexpected).
3336
warnings.warn(f"{e}", stacklevel=2)
3437

35-
def checked(func): # type: ignore[no-redef]
38+
def checked(func: T) -> T: # type: ignore[no-redef]
3639
return func

onnxscript/function_libs/torch_lib/_flags.py

+4
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,7 @@ def _load_boolean_flag(
4343
"TORCHLIB_EXPERIMENTAL_PREFER_TRACING",
4444
this_will="trace all traceable functions to fold if branches and collapse constant expressions",
4545
)
46+
EXPERIMENTAL_USE_IR: bool = _load_boolean_flag(
47+
"TORCHLIB_EXPERIMENTAL_USE_IR",
48+
this_will="use the ONNX IR instead of the PyTorch Graph for graph building",
49+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""APIs for building an ONNX graph from a PyTorch model.
2+
3+
This module exposes only three classes that will be used to build an ONNX graph
4+
by the ONNX exporter in PyTorch:
5+
6+
- :class:`TorchScriptTensor`: Represents a symbolic value in the ONNX graph.
7+
- :class:`TorchScriptGraph`: Stores the graph being built.
8+
- :class:`TorchScriptTracingEvaluator`: An evaluator that will record all operators
9+
applied on the ``TorchScriptTensor``. It has a reference to the ``TorchScriptGraph``
10+
being built, will write to it, and will handle eager evaluations of Torch Lib
11+
functions when desired.
12+
13+
The usage is in https://github.com/pytorch/pytorch/blob/136f8378e1b5a8cb7127977b8d068fbf9c3e1247/torch/onnx/_internal/fx/fx_onnx_interpreter.py#L698-L702,
14+
and it is very simple::
15+
16+
with onnxscript.evaluator.default_as(onnxscript_tracer): # onnxscript_tracer is a TorchScriptTracingEvaluator
17+
output: Union[
18+
onnxscript_graph_building.TorchScriptTensor,
19+
Tuple[onnxscript_graph_building.TorchScriptTensor, ...],
20+
] = symbolic_fn(*onnx_args, **onnx_kwargs)
21+
22+
Here, we set the default evaluator to be ``onnxscript_tracer`` so
23+
that ONNX Script will dispatch all operators calls to the evaluator. The ``symbolic_fn``
24+
can be a pure Python function (e.g. trace-only) or an ONNX Script function. Either way,
25+
they are recorded by ``onnxscript_tracer`` and onto the graph.
26+
27+
The outputs, as ``TorchScriptTensor``, are then handed by to the exporter. On line
28+
https://github.com/pytorch/pytorch/blob/136f8378e1b5a8cb7127977b8d068fbf9c3e1247/torch/onnx/_internal/fx/fx_onnx_interpreter.py#L707
29+
the exporter fills in type and shape information from PyTorch by calling the setters
30+
on ``TorchScriptTensor.dtype`` and ``TorchScriptTensor.shape``.
31+
"""
32+
33+
from __future__ import annotations
34+
35+
__all__ = [
36+
"TorchScriptTensor",
37+
"TorchScriptGraph",
38+
"TorchScriptTracingEvaluator",
39+
]
40+
41+
from onnxscript.function_libs.torch_lib import _flags
42+
43+
if _flags.EXPERIMENTAL_USE_IR:
44+
from ._graph_building_ir import (
45+
TorchScriptGraph,
46+
TorchScriptTensor,
47+
TorchScriptTracingEvaluator,
48+
)
49+
else:
50+
from ._graph_building_torch import ( # type: ignore[assignment]
51+
TorchScriptGraph,
52+
TorchScriptTensor,
53+
TorchScriptTracingEvaluator,
54+
)

0 commit comments

Comments
 (0)