-
Notifications
You must be signed in to change notification settings - Fork 71
[IR] Improve pass infra #2120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[IR] Improve pass infra #2120
Changes from all commits
d63df78
6fa351a
e2bed4a
0ced04d
c2b1385
3fffe61
f9ab36a
01cee46
ff947e9
3caf818
dd5d17b
2d77e36
7eea1f8
5f906c9
b38528a
460a41a
996551f
8703097
9b52641
0b059b9
2f39f59
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,9 @@ | |
|
||
__all__ = [ | ||
"PassBase", | ||
"Sequential", | ||
"InPlacePass", | ||
"FunctionalPass", | ||
"PassManager", | ||
"PassResult", | ||
# Errors | ||
|
@@ -68,14 +71,72 @@ | |
class PassBase(abc.ABC): | ||
"""Base class for all passes. | ||
|
||
Class attributes: | ||
in_place: Whether the pass modifies the model in place. | ||
|
||
``in_place`` and ``changes_input`` properties and what they mean: | ||
|
||
+------------+------------------+----------------------------+ | ||
| | changes_inputs | not changes_inputs | | ||
+------------+------------------+----------------------------+ | ||
| in_place | in place | Side-effect-only pass | | ||
+------------+------------------+----------------------------+ | ||
| not | destructive | functional | | ||
| in_place | | | | ||
+------------+------------------+----------------------------+ | ||
""" | ||
|
||
in_place: bool = True | ||
@property | ||
@abc.abstractmethod | ||
def in_place(self) -> bool: | ||
"""Whether the pass modifies the model in place and returns it. | ||
|
||
If True, the pass will return the same model object that was passed in. | ||
If False, the pass will return a new model object. | ||
""" | ||
raise NotImplementedError | ||
|
||
@property | ||
@abc.abstractmethod | ||
def changes_input(self) -> bool: | ||
"""Whether the pass modifies input model.""" | ||
raise NotImplementedError | ||
|
||
@property | ||
def destructive(self) -> bool: | ||
"""Whether the pass will destroy the input model when ``in_place=False``. | ||
|
||
A pass is destructive if it is not in place and it modifies the input model. | ||
""" | ||
return not self.in_place and self.changes_input | ||
|
||
def __call__(self, model: ir.Model) -> PassResult: | ||
return self.call(model) | ||
# Check preconditions | ||
try: | ||
self.requires(model) | ||
except PreconditionError: | ||
raise | ||
except Exception as e: | ||
raise PreconditionError( | ||
f"Pre-condition for pass '{self.__class__.__name__}' failed" | ||
) from e | ||
|
||
result = self.call(model) | ||
|
||
# Check postconditions | ||
try: | ||
self.ensures(model) | ||
except PostconditionError: | ||
raise | ||
except Exception as e: | ||
raise PostconditionError( | ||
f"Post-condition for pass '{self.__class__.__name__}' failed" | ||
) from e | ||
|
||
if not isinstance(result, PassResult): | ||
raise TypeError( | ||
f"The result of the pass '{self.__class__.__name__}' should be type PassResult. " | ||
"Please create one with ir.passes.PassResult()." | ||
) | ||
return result | ||
|
||
@abc.abstractmethod | ||
def call(self, model: ir.Model) -> PassResult: | ||
|
@@ -97,76 +158,105 @@ | |
del model # Unused | ||
|
||
|
||
class PassManager: | ||
class InPlacePass(PassBase): | ||
"""A pass that modifies the input model in place and returns it.""" | ||
|
||
@property | ||
def in_place(self) -> bool: | ||
return True | ||
|
||
@property | ||
def changes_input(self) -> bool: | ||
return True | ||
|
||
|
||
class FunctionalPass(PassBase): | ||
"""A pass that returns a new model but does not modify the input model.""" | ||
|
||
@property | ||
def in_place(self) -> bool: | ||
return False | ||
|
||
@property | ||
def changes_input(self) -> bool: | ||
return False | ||
|
||
|
||
class Sequential(PassBase): | ||
"""Run a sequence of passes in order.""" | ||
|
||
def __init__(self, *passes: PassBase): | ||
if not passes: | ||
raise ValueError("Sequential must take at least one pass") | ||
self.passes = passes | ||
self._in_place = all(pass_.in_place for pass_ in passes) | ||
# The reason changes_inputs is decided by the first pass is that if the first pass is either in-place, | ||
# or if it is not designed to be in-place but somehow changes the input (destructive), | ||
# this pass sequence will change inputs. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean would second or other passes that changes inputs after the first pass? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the first pass is functional, the second pass will take the new model that the first pass returns, which means the second pass has no chance to affect the input model. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But if the first pass is side-effect only pass. Shouldn't we check the following passes? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. In that case we just assume it changes the model for now. I think that's ok? |
||
self._changes_input = self.passes[0].changes_input or self.passes[0].in_place | ||
justinchuby marked this conversation as resolved.
Show resolved
Hide resolved
justinchuby marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@property | ||
def in_place(self) -> bool: | ||
return self._in_place | ||
|
||
@property | ||
def changes_input(self) -> bool: | ||
return self._changes_input | ||
|
||
def call(self, model: ir.Model) -> PassResult: | ||
modified = False | ||
for i, pass_ in enumerate(self.passes): | ||
logger.debug("Running the %s-th pass '%s'", i, pass_) | ||
try: | ||
pass_result = pass_(model) | ||
except Exception as e: | ||
prev_pass_names = [str(p) for p in self.passes[:i]] | ||
raise PassError( | ||
f"An error occurred when running the '{pass_}' pass after the " | ||
f"following passes: {prev_pass_names}" | ||
) from e | ||
|
||
model = pass_result.model | ||
modified = modified or pass_result.modified | ||
|
||
return PassResult(model, modified) | ||
|
||
|
||
class PassManager(Sequential): | ||
"""Pass manager for the IR. | ||
|
||
The PassManager is a callable that runs a sequence of passes on a model. | ||
The PassManager is a Pass that runs a sequence of passes on a model. | ||
|
||
Attributes: | ||
passes: The passes to run. | ||
check_invariants: Whether to check invariants before and after each pass. | ||
steps: The number of times to run the passes. | ||
early_stop: Whether to stop running the passes if the graph stops changing. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
passes: Sequence[PassBase], | ||
check_invariants: bool = False, | ||
steps: int = 1, | ||
early_stop: bool = True, | ||
): | ||
# TODO(justinchuby): Implement constraints | ||
self.passes = list(passes) | ||
self.check_invariants = check_invariants | ||
super().__init__(*passes) | ||
self.steps = steps | ||
self.early_stop = early_stop | ||
|
||
def __call__(self, model: ir.Model) -> PassResult: | ||
def call(self, model: ir.Model) -> PassResult: | ||
"""Run the set of passes `steps` number of times or until the graph stops changing.""" | ||
overall_modified = False | ||
for step in range(self.steps): | ||
step_result = self._run_one_step(model, step) | ||
try: | ||
step_result = super().__call__(model) | ||
shubhambhokare1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
except Exception as e: | ||
raise PassError(f"An error occurred at step {step}") from e | ||
model = step_result.model | ||
modified = step_result.modified | ||
overall_modified = overall_modified or modified | ||
# If the graph no longer changes, then we can stop running these passes | ||
if not modified: | ||
if not modified and self.early_stop: | ||
justinchuby marked this conversation as resolved.
Show resolved
Hide resolved
|
||
logger.info("PassManager: No more graph changes detected after step %s", step) | ||
break | ||
return PassResult(model, overall_modified) | ||
|
||
def _run_one_step(self, model: ir.Model, step: int) -> PassResult: | ||
modified = False | ||
for i, pass_ in enumerate(self.passes): | ||
logger.debug("Running the %s-th pass '%s', (step %s)", i, pass_, step) | ||
|
||
# 1. Check preconditions | ||
if self.check_invariants: | ||
try: | ||
pass_.requires(model) | ||
except Exception as e: | ||
raise PreconditionError(f"Pre-condition failed for {pass_}") from e | ||
|
||
# 2. Run the pass | ||
try: | ||
pass_result = pass_(model) | ||
except Exception as e: | ||
prev_pass_names = [str(p) for p in self.passes[:i]] | ||
raise PassError( | ||
f"An error occurred when running the '{pass_}' pass after the " | ||
f"following passes: {prev_pass_names} during step {step}" | ||
) from e | ||
if not isinstance(pass_result, PassResult): | ||
raise TypeError( | ||
f"The result of the pass {pass_} should be type PassResult." | ||
"Please create one with ir.passes.PassResult()." | ||
) | ||
|
||
model = pass_result.model | ||
modified = modified or pass_result.modified | ||
|
||
# 3. Check postconditions | ||
if self.check_invariants: | ||
try: | ||
pass_.ensures(model) | ||
except Exception as e: | ||
raise PostconditionError(f"Post-condition failed for {pass_}") from e | ||
return PassResult(model, modified) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to add a paragraph in the documentation of the exporter (torch.onnx.export) to mention the list of passes applied to the fx graph?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure!