forked from llvm/torch-mlir
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[custom op] Generalize shape library logic to work with dtypes (llvm#…
…1594) * [custom op] Generalize shape library logic to work with dtypes This commit generalizes the shape library logic, so that dtype rules for ops can also be expressed using the same mechanism. In other words, each op can now have a shape function and a dtype function specified in Python that is imported during lowering to calculate the shapes and dtypes throught a program. For more information about how to specify a dtype function, see the updated `docs/adding_a_shape_and_dtype_function.md`. For those not familiar with how the shape library works, the file `docs/calculations_lib.md` provides an overview.
- Loading branch information
Showing
45 changed files
with
3,739 additions
and
2,312 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# Torch-MLIR Abstract Interpretation Library Infrastructure | ||
|
||
## Overview | ||
|
||
The Torch-MLIR project has an infrastructure for maintaining a library of | ||
calculation functions for different Torch operators, which supply extra | ||
information such as result dtypes and shapes as well as decompositions. These | ||
functions are fully executable specifications of the shape/dtype/decomposition | ||
functions for each operator and can be authored and tested from Python for | ||
convenience. These are then brought into the compiler and can be manipulated / | ||
transformed for various purposes. Additionally, in the case of shape functions, | ||
this effort is synergistic with upstream PyTorch efforts to maintain a library | ||
of shape functions. | ||
|
||
The two main use cases are: | ||
|
||
- Refinement / inference. The `torch-shape-refinement-pipeline` and | ||
`torch-dtype-refinement-pipeline` pass pipelines orchestrate a series of | ||
passes that use the available information in the program to further refine the | ||
types in the program. | ||
|
||
- Error guard insertion for backends (Not Yet Implemented). The executable | ||
functions can include error guards / assertions that abort the program in case | ||
of invalid input (such as a matmul with a mismatching contracting dimension). | ||
|
||
## Architecture | ||
|
||
Functions are defined as TorchScript-able Python functions in | ||
`python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py`. | ||
The signatures of the functions are systematically derived from Torch JIT | ||
operator registry. Most shape functions are expected to reuse the upstream | ||
helper functions | ||
[`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1), | ||
and any new shape functions should be added there. | ||
|
||
The `build_tools/update_abstract_interp_lib.sh` script invokes | ||
`abstract_interp_lib_gen.py` to generate an MLIR module containing the functions, | ||
which is currently embedded as a string literal in | ||
`lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp`. | ||
|
||
The function `StringRef mlir::torch::Torch::getAbstractInterpLibrary()` is | ||
available for use inside the compiler any time that the library is needed. | ||
|
||
## Shape and Dtype Refinement Pipeline Architecture | ||
|
||
One of the main services that Torch-MLIR provides for backends is to normalize | ||
all Torch frontends into a common form which includes tensor shapes and dtypes | ||
that are as precise as possible. This alleviates the need for backends to solve | ||
this problem themselves. This process of shape and dtype refinement is | ||
accomplished in Torch-MLIR through a pipeline of passes which uses the abstract | ||
interpretation library combined with abstract interpretation of the calculation | ||
functions to calculate shapes and dtypes that are as precise as possible. | ||
|
||
The pipeline works as follows: | ||
|
||
1. Calculations are reified. The `torch-reify-shape-calculations` and | ||
`torch-reify-dtype-calculations` passes reify (i.e., materializes into the | ||
IR) the functions for each op with a function in the calculation library. To | ||
do this, the passes wrap those ops in a `torch.shape.calculate` or | ||
`torch.dtype.calculate` op, respectively, which has two regions: 1) a body | ||
with the op itself, and 2) the shape or dtype calculation, which calculates | ||
the shapes or dtypes of the tensors yielded by the body. | ||
|
||
2. Simplifying the functions and propagating the shapes and dtypes. After the | ||
functions are reified, we then attempt to "optimize hard enough" until the | ||
shapes and dtypes yielded by the calculation regions become obvious in the IR. | ||
Those results are propagated through the IR, which usually reveals more | ||
opportunities for simplification. | ||
|
||
a. After reification, the functions are just a loose collection of | ||
functions, which are difficult to analyze. The first step is to inline them. | ||
|
||
b. After inlining, the `torch-simplify-shape-calculations` and | ||
`torch-simplify-dtype-calculations` passes are used to simplify the | ||
calculations. These passes bring in a number of targeted canonicalization | ||
patterns and folds, along with a few specific patterns such as unrolling | ||
fixed-trip-count loops and abstractly interpreting list operations (an | ||
example is turning a series of "append" operations into a list | ||
literal). These passes also look at the values yielded by the calculation | ||
regions, and if the resulting shape or dtype can be deduced by looking at the | ||
IR (for example, the shape is the list literal `[1, 2, 3]`), it will refine | ||
the types of the `torch.shape.calculate` and `torch.dtype.calculate` | ||
ops. This usually yields more opportunities for simplification. This process | ||
runs to a fixed-point. | ||
|
||
3. Dropping the calculations. Once all the types in the program have been | ||
refined as much as possible, the ops that were originally wrapped in | ||
`torch.shape.calculate` and `torch.dtype.calculate` are unwrapped by the | ||
`torch-drop-abstract-interp-calculations` pass which drops the reified | ||
calculations, leaving behind the shape and dtype refined program. | ||
|
||
Inferring precise shapes and dtypes often is needed for correctness by | ||
backends. That said, requiring "optimizing hard enough" for correctness is | ||
usually considered quite brittle in a compiler flow. In this case, the saving | ||
grace is that we are only optimizing the functions, which are authored by | ||
compiler developers (not users), and thus there is some give-and-take in terms | ||
of understanding the optimizable constructs while authoring the functions, or | ||
improving the optimizations to enable easier authoring. Some brittleness is | ||
likely to escape to users, unfortunately, since there will always be situations | ||
where, for example, a statically shaped program allows the shape functions to be | ||
simplified to a greater extent than in a dynamically shaped program (for | ||
example, if the shape function checks "is this dimension of size 1"). We hope | ||
that this is minimal. | ||
|
||
## Adding to the abstract interpretation library | ||
|
||
See [Adding a Shape and Dtype Function](adding_a_shape_and_dtype_function.md) | ||
for details on how to add a shape and dtype function for an operator. | ||
|
||
## Rationale | ||
|
||
### Use of full operator signatures | ||
|
||
The use of the full operator signature such as | ||
`def aten〇add〇Tensor(self: List[int], other: List[int], alpha: float = 1) -> List[int]:` | ||
for defining calculation functions is somewhat verbose and repetitive, especially when | ||
there are multiple identical functions. Upstream uses a map with key-value | ||
pairs like `"aten.add.Tensor": upstream_shape_functions.broadcast`, which is | ||
more compact and less repetitive in some ways (upstream also allows trailing | ||
arguments beyond those accepted by the shape function to be ignored, allowing | ||
further deduplication). The decision to do it the more verbose way in Torch-MLIR | ||
was based on the following goals: | ||
|
||
- To make the system very easy to debug and test. | ||
|
||
- To make the system maximally consistent between functions that are | ||
implemented with the upstream shape helpers and the ones that are manually | ||
written, which are still a fairly large and non-trivial set. | ||
|
||
- To make it as mechanical as possible to add a new function. |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.