20
20
21
21
__all__ = [
22
22
"PassBase" ,
23
+ "Sequential" ,
24
+ "InPlacePass" ,
25
+ "FunctionalPass" ,
23
26
"PassManager" ,
24
27
"PassResult" ,
25
28
# Errors
@@ -68,14 +71,72 @@ class PassResult:
68
71
class PassBase (abc .ABC ):
69
72
"""Base class for all passes.
70
73
71
- Class attributes:
72
- in_place: Whether the pass modifies the model in place.
74
+
75
+ ``in_place`` and ``changes_input`` properties and what they mean:
76
+
77
+ +------------+------------------+----------------------------+
78
+ | | changes_inputs | not changes_inputs |
79
+ +------------+------------------+----------------------------+
80
+ | in_place | in place | Side-effect-only pass |
81
+ +------------+------------------+----------------------------+
82
+ | not | destructive | functional |
83
+ | in_place | | |
84
+ +------------+------------------+----------------------------+
73
85
"""
74
86
75
- in_place : bool = True
87
+ @property
88
+ @abc .abstractmethod
89
+ def in_place (self ) -> bool :
90
+ """Whether the pass modifies the model in place and returns it.
91
+
92
+ If True, the pass will return the same model object that was passed in.
93
+ If False, the pass will return a new model object.
94
+ """
95
+ raise NotImplementedError
96
+
97
+ @property
98
+ @abc .abstractmethod
99
+ def changes_input (self ) -> bool :
100
+ """Whether the pass modifies input model."""
101
+ raise NotImplementedError
102
+
103
+ @property
104
+ def destructive (self ) -> bool :
105
+ """Whether the pass will destroy the input model when ``in_place=False``.
106
+
107
+ A pass is destructive if it is not in place and it modifies the input model.
108
+ """
109
+ return not self .in_place and self .changes_input
76
110
77
111
def __call__ (self , model : ir .Model ) -> PassResult :
78
- return self .call (model )
112
+ # Check preconditions
113
+ try :
114
+ self .requires (model )
115
+ except PreconditionError :
116
+ raise
117
+ except Exception as e :
118
+ raise PreconditionError (
119
+ f"Pre-condition for pass '{ self .__class__ .__name__ } ' failed"
120
+ ) from e
121
+
122
+ result = self .call (model )
123
+
124
+ # Check postconditions
125
+ try :
126
+ self .ensures (model )
127
+ except PostconditionError :
128
+ raise
129
+ except Exception as e :
130
+ raise PostconditionError (
131
+ f"Post-condition for pass '{ self .__class__ .__name__ } ' failed"
132
+ ) from e
133
+
134
+ if not isinstance (result , PassResult ):
135
+ raise TypeError (
136
+ f"The result of the pass '{ self .__class__ .__name__ } ' should be type PassResult. "
137
+ "Please create one with ir.passes.PassResult()."
138
+ )
139
+ return result
79
140
80
141
@abc .abstractmethod
81
142
def call (self , model : ir .Model ) -> PassResult :
@@ -97,76 +158,105 @@ def ensures(self, model: ir.Model) -> None:
97
158
del model # Unused
98
159
99
160
100
- class PassManager :
161
+ class InPlacePass (PassBase ):
162
+ """A pass that modifies the input model in place and returns it."""
163
+
164
+ @property
165
+ def in_place (self ) -> bool :
166
+ return True
167
+
168
+ @property
169
+ def changes_input (self ) -> bool :
170
+ return True
171
+
172
+
173
+ class FunctionalPass (PassBase ):
174
+ """A pass that returns a new model but does not modify the input model."""
175
+
176
+ @property
177
+ def in_place (self ) -> bool :
178
+ return False
179
+
180
+ @property
181
+ def changes_input (self ) -> bool :
182
+ return False
183
+
184
+
185
+ class Sequential (PassBase ):
186
+ """Run a sequence of passes in order."""
187
+
188
+ def __init__ (self , * passes : PassBase ):
189
+ if not passes :
190
+ raise ValueError ("Sequential must take at least one pass" )
191
+ self .passes = passes
192
+ self ._in_place = all (pass_ .in_place for pass_ in passes )
193
+ # The reason changes_inputs is decided by the first pass is that if the first pass is either in-place,
194
+ # or if it is not designed to be in-place but somehow changes the input (destructive),
195
+ # this pass sequence will change inputs.
196
+ self ._changes_input = self .passes [0 ].changes_input or self .passes [0 ].in_place
197
+
198
+ @property
199
+ def in_place (self ) -> bool :
200
+ return self ._in_place
201
+
202
+ @property
203
+ def changes_input (self ) -> bool :
204
+ return self ._changes_input
205
+
206
+ def call (self , model : ir .Model ) -> PassResult :
207
+ modified = False
208
+ for i , pass_ in enumerate (self .passes ):
209
+ logger .debug ("Running the %s-th pass '%s'" , i , pass_ )
210
+ try :
211
+ pass_result = pass_ (model )
212
+ except Exception as e :
213
+ prev_pass_names = [str (p ) for p in self .passes [:i ]]
214
+ raise PassError (
215
+ f"An error occurred when running the '{ pass_ } ' pass after the "
216
+ f"following passes: { prev_pass_names } "
217
+ ) from e
218
+
219
+ model = pass_result .model
220
+ modified = modified or pass_result .modified
221
+
222
+ return PassResult (model , modified )
223
+
224
+
225
+ class PassManager (Sequential ):
101
226
"""Pass manager for the IR.
102
227
103
- The PassManager is a callable that runs a sequence of passes on a model.
228
+ The PassManager is a Pass that runs a sequence of passes on a model.
104
229
105
230
Attributes:
106
231
passes: The passes to run.
107
- check_invariants: Whether to check invariants before and after each pass.
108
232
steps: The number of times to run the passes.
233
+ early_stop: Whether to stop running the passes if the graph stops changing.
109
234
"""
110
235
111
236
def __init__ (
112
237
self ,
113
238
passes : Sequence [PassBase ],
114
- check_invariants : bool = False ,
115
239
steps : int = 1 ,
240
+ early_stop : bool = True ,
116
241
):
117
242
# TODO(justinchuby): Implement constraints
118
- self .passes = list (passes )
119
- self .check_invariants = check_invariants
243
+ super ().__init__ (* passes )
120
244
self .steps = steps
245
+ self .early_stop = early_stop
121
246
122
- def __call__ (self , model : ir .Model ) -> PassResult :
247
+ def call (self , model : ir .Model ) -> PassResult :
123
248
"""Run the set of passes `steps` number of times or until the graph stops changing."""
124
249
overall_modified = False
125
250
for step in range (self .steps ):
126
- step_result = self ._run_one_step (model , step )
251
+ try :
252
+ step_result = super ().__call__ (model )
253
+ except Exception as e :
254
+ raise PassError (f"An error occurred at step { step } " ) from e
127
255
model = step_result .model
128
256
modified = step_result .modified
129
257
overall_modified = overall_modified or modified
130
258
# If the graph no longer changes, then we can stop running these passes
131
- if not modified :
259
+ if not modified and self . early_stop :
132
260
logger .info ("PassManager: No more graph changes detected after step %s" , step )
133
261
break
134
262
return PassResult (model , overall_modified )
135
-
136
- def _run_one_step (self , model : ir .Model , step : int ) -> PassResult :
137
- modified = False
138
- for i , pass_ in enumerate (self .passes ):
139
- logger .debug ("Running the %s-th pass '%s', (step %s)" , i , pass_ , step )
140
-
141
- # 1. Check preconditions
142
- if self .check_invariants :
143
- try :
144
- pass_ .requires (model )
145
- except Exception as e :
146
- raise PreconditionError (f"Pre-condition failed for { pass_ } " ) from e
147
-
148
- # 2. Run the pass
149
- try :
150
- pass_result = pass_ (model )
151
- except Exception as e :
152
- prev_pass_names = [str (p ) for p in self .passes [:i ]]
153
- raise PassError (
154
- f"An error occurred when running the '{ pass_ } ' pass after the "
155
- f"following passes: { prev_pass_names } during step { step } "
156
- ) from e
157
- if not isinstance (pass_result , PassResult ):
158
- raise TypeError (
159
- f"The result of the pass { pass_ } should be type PassResult."
160
- "Please create one with ir.passes.PassResult()."
161
- )
162
-
163
- model = pass_result .model
164
- modified = modified or pass_result .modified
165
-
166
- # 3. Check postconditions
167
- if self .check_invariants :
168
- try :
169
- pass_ .ensures (model )
170
- except Exception as e :
171
- raise PostconditionError (f"Post-condition failed for { pass_ } " ) from e
172
- return PassResult (model , modified )
0 commit comments