Skip to content

Commit 2771ad1

Browse files
justinchubyCopilot
authored andcommitted
[IR] Improve pass infra (microsoft#2120)
1. Run invariant functions `requires` and `ensures` by default at Pass `__call__` to match pytorch's pass behavior. This means the invariants cannot be too expensive because they are always checked. 2. Make PassManager a `Pass` so that it can be composed. 3. Add `changes_input` attribute to indicate if the input is changed. Turn two class attributes into properties for them to be dynamic. Combining the two attributes we can tell if a pass is destructive. For now the properties are unused but they will become useful when we want to have a better guard on pass usage etc. 4. Create `Sequential`, `InPlacePass`, `FunctionalPass` to help users create passes. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 6024d7c commit 2771ad1

File tree

6 files changed

+151
-58
lines changed

6 files changed

+151
-58
lines changed

onnxscript/ir/passes/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
"PassBase",
66
"PassResult",
77
"PassManager",
8+
"Sequential",
9+
"InPlacePass",
10+
"FunctionalPass",
811
# Errors
912
"InvariantError",
1013
"PreconditionError",
@@ -13,13 +16,16 @@
1316
]
1417

1518
from onnxscript.ir.passes._pass_infra import (
19+
FunctionalPass,
20+
InPlacePass,
1621
InvariantError,
1722
PassBase,
1823
PassError,
1924
PassManager,
2025
PassResult,
2126
PostconditionError,
2227
PreconditionError,
28+
Sequential,
2329
)
2430

2531

onnxscript/ir/passes/_pass_infra.py

+141-51
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020

2121
__all__ = [
2222
"PassBase",
23+
"Sequential",
24+
"InPlacePass",
25+
"FunctionalPass",
2326
"PassManager",
2427
"PassResult",
2528
# Errors
@@ -68,14 +71,72 @@ class PassResult:
6871
class PassBase(abc.ABC):
6972
"""Base class for all passes.
7073
71-
Class attributes:
72-
in_place: Whether the pass modifies the model in place.
74+
75+
``in_place`` and ``changes_input`` properties and what they mean:
76+
77+
+------------+------------------+----------------------------+
78+
| | changes_inputs | not changes_inputs |
79+
+------------+------------------+----------------------------+
80+
| in_place | in place | Side-effect-only pass |
81+
+------------+------------------+----------------------------+
82+
| not | destructive | functional |
83+
| in_place | | |
84+
+------------+------------------+----------------------------+
7385
"""
7486

75-
in_place: bool = True
87+
@property
88+
@abc.abstractmethod
89+
def in_place(self) -> bool:
90+
"""Whether the pass modifies the model in place and returns it.
91+
92+
If True, the pass will return the same model object that was passed in.
93+
If False, the pass will return a new model object.
94+
"""
95+
raise NotImplementedError
96+
97+
@property
98+
@abc.abstractmethod
99+
def changes_input(self) -> bool:
100+
"""Whether the pass modifies input model."""
101+
raise NotImplementedError
102+
103+
@property
104+
def destructive(self) -> bool:
105+
"""Whether the pass will destroy the input model when ``in_place=False``.
106+
107+
A pass is destructive if it is not in place and it modifies the input model.
108+
"""
109+
return not self.in_place and self.changes_input
76110

77111
def __call__(self, model: ir.Model) -> PassResult:
78-
return self.call(model)
112+
# Check preconditions
113+
try:
114+
self.requires(model)
115+
except PreconditionError:
116+
raise
117+
except Exception as e:
118+
raise PreconditionError(
119+
f"Pre-condition for pass '{self.__class__.__name__}' failed"
120+
) from e
121+
122+
result = self.call(model)
123+
124+
# Check postconditions
125+
try:
126+
self.ensures(model)
127+
except PostconditionError:
128+
raise
129+
except Exception as e:
130+
raise PostconditionError(
131+
f"Post-condition for pass '{self.__class__.__name__}' failed"
132+
) from e
133+
134+
if not isinstance(result, PassResult):
135+
raise TypeError(
136+
f"The result of the pass '{self.__class__.__name__}' should be type PassResult. "
137+
"Please create one with ir.passes.PassResult()."
138+
)
139+
return result
79140

80141
@abc.abstractmethod
81142
def call(self, model: ir.Model) -> PassResult:
@@ -97,76 +158,105 @@ def ensures(self, model: ir.Model) -> None:
97158
del model # Unused
98159

99160

100-
class PassManager:
161+
class InPlacePass(PassBase):
162+
"""A pass that modifies the input model in place and returns it."""
163+
164+
@property
165+
def in_place(self) -> bool:
166+
return True
167+
168+
@property
169+
def changes_input(self) -> bool:
170+
return True
171+
172+
173+
class FunctionalPass(PassBase):
174+
"""A pass that returns a new model but does not modify the input model."""
175+
176+
@property
177+
def in_place(self) -> bool:
178+
return False
179+
180+
@property
181+
def changes_input(self) -> bool:
182+
return False
183+
184+
185+
class Sequential(PassBase):
186+
"""Run a sequence of passes in order."""
187+
188+
def __init__(self, *passes: PassBase):
189+
if not passes:
190+
raise ValueError("Sequential must take at least one pass")
191+
self.passes = passes
192+
self._in_place = all(pass_.in_place for pass_ in passes)
193+
# The reason changes_inputs is decided by the first pass is that if the first pass is either in-place,
194+
# or if it is not designed to be in-place but somehow changes the input (destructive),
195+
# this pass sequence will change inputs.
196+
self._changes_input = self.passes[0].changes_input or self.passes[0].in_place
197+
198+
@property
199+
def in_place(self) -> bool:
200+
return self._in_place
201+
202+
@property
203+
def changes_input(self) -> bool:
204+
return self._changes_input
205+
206+
def call(self, model: ir.Model) -> PassResult:
207+
modified = False
208+
for i, pass_ in enumerate(self.passes):
209+
logger.debug("Running the %s-th pass '%s'", i, pass_)
210+
try:
211+
pass_result = pass_(model)
212+
except Exception as e:
213+
prev_pass_names = [str(p) for p in self.passes[:i]]
214+
raise PassError(
215+
f"An error occurred when running the '{pass_}' pass after the "
216+
f"following passes: {prev_pass_names}"
217+
) from e
218+
219+
model = pass_result.model
220+
modified = modified or pass_result.modified
221+
222+
return PassResult(model, modified)
223+
224+
225+
class PassManager(Sequential):
101226
"""Pass manager for the IR.
102227
103-
The PassManager is a callable that runs a sequence of passes on a model.
228+
The PassManager is a Pass that runs a sequence of passes on a model.
104229
105230
Attributes:
106231
passes: The passes to run.
107-
check_invariants: Whether to check invariants before and after each pass.
108232
steps: The number of times to run the passes.
233+
early_stop: Whether to stop running the passes if the graph stops changing.
109234
"""
110235

111236
def __init__(
112237
self,
113238
passes: Sequence[PassBase],
114-
check_invariants: bool = False,
115239
steps: int = 1,
240+
early_stop: bool = True,
116241
):
117242
# TODO(justinchuby): Implement constraints
118-
self.passes = list(passes)
119-
self.check_invariants = check_invariants
243+
super().__init__(*passes)
120244
self.steps = steps
245+
self.early_stop = early_stop
121246

122-
def __call__(self, model: ir.Model) -> PassResult:
247+
def call(self, model: ir.Model) -> PassResult:
123248
"""Run the set of passes `steps` number of times or until the graph stops changing."""
124249
overall_modified = False
125250
for step in range(self.steps):
126-
step_result = self._run_one_step(model, step)
251+
try:
252+
step_result = super().__call__(model)
253+
except Exception as e:
254+
raise PassError(f"An error occurred at step {step}") from e
127255
model = step_result.model
128256
modified = step_result.modified
129257
overall_modified = overall_modified or modified
130258
# If the graph no longer changes, then we can stop running these passes
131-
if not modified:
259+
if not modified and self.early_stop:
132260
logger.info("PassManager: No more graph changes detected after step %s", step)
133261
break
134262
return PassResult(model, overall_modified)
135-
136-
def _run_one_step(self, model: ir.Model, step: int) -> PassResult:
137-
modified = False
138-
for i, pass_ in enumerate(self.passes):
139-
logger.debug("Running the %s-th pass '%s', (step %s)", i, pass_, step)
140-
141-
# 1. Check preconditions
142-
if self.check_invariants:
143-
try:
144-
pass_.requires(model)
145-
except Exception as e:
146-
raise PreconditionError(f"Pre-condition failed for {pass_}") from e
147-
148-
# 2. Run the pass
149-
try:
150-
pass_result = pass_(model)
151-
except Exception as e:
152-
prev_pass_names = [str(p) for p in self.passes[:i]]
153-
raise PassError(
154-
f"An error occurred when running the '{pass_}' pass after the "
155-
f"following passes: {prev_pass_names} during step {step}"
156-
) from e
157-
if not isinstance(pass_result, PassResult):
158-
raise TypeError(
159-
f"The result of the pass {pass_} should be type PassResult."
160-
"Please create one with ir.passes.PassResult()."
161-
)
162-
163-
model = pass_result.model
164-
modified = modified or pass_result.modified
165-
166-
# 3. Check postconditions
167-
if self.check_invariants:
168-
try:
169-
pass_.ensures(model)
170-
except Exception as e:
171-
raise PostconditionError(f"Post-condition failed for {pass_}") from e
172-
return PassResult(model, modified)

onnxscript/ir/passes/common/shape_inference.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,9 @@
2222
_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB
2323

2424

25-
class ShapeInferencePass(ir.passes.PassBase):
25+
class ShapeInferencePass(ir.passes.FunctionalPass):
2626
"""This pass performs shape inference on the graph."""
2727

28-
# This pass does not modify the model in place.
29-
in_place = False
30-
3128
def __init__(
3229
self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True
3330
) -> None:

onnxscript/optimizer/_constant_folding.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ def merge_dims(dim1, dim2):
797797
return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
798798

799799

800-
class FoldConstantsPass(ir.passes.PassBase):
800+
class FoldConstantsPass(ir.passes.InPlacePass):
801801
def __init__(
802802
self,
803803
*,

onnxscript/optimizer/_remove_unused.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _process_function_or_graph(function_or_graph: ir.Function | ir.Graph) -> int
8282
return count
8383

8484

85-
class RemoveUnusedNodesPass(ir.passes.PassBase):
85+
class RemoveUnusedNodesPass(ir.passes.InPlacePass):
8686
def call(self, model: ir.Model) -> ir.passes.PassResult:
8787
count = _process_function_or_graph(model.graph)
8888
graph_outputs = frozenset(model.graph.outputs)

onnxscript/optimizer/_remove_unused_function.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _clean_up_unused_functions(model: ir.Model, unused: set[ir.OperatorIdentifie
2525
logger.debug("Functions removed: %s", unused)
2626

2727

28-
class RemoveUnusedFunctionPass(ir.passes.PassBase):
28+
class RemoveUnusedFunctionPass(ir.passes.InPlacePass):
2929
def __init__(self):
3030
super().__init__()
3131
self.used: set[ir.OperatorIdentifier] | None = None

0 commit comments

Comments
 (0)