-
Notifications
You must be signed in to change notification settings - Fork 522
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[custom op] Generalize shape library logic to work with dtypes
This commit generalizes the shape library logic, so that dtype rules for ops can also be expressed using the same mechanism. In other words, each op can now have a shape function and a dtype function specified in Python that is imported during lowering to calculate the shapes and dtypes throught a program. For more information about how to specify a dtype function, see the updated `docs/adding_a_shape_and_dtype_function.md`. For those not familiar with how the shape library works, the file `docs/calculations_lib.md` provides an overview. To make the reviewing a bit easier, I suggest the following review order: 1. Get familiar with the overall architecture by reading `docs/calculations_lib.md` 2. New op declarations - `include/torch-mlir/Dialect/Torch/IR/TorchOps.td` - `lib/Dialect/Torch/IR/TorchOps.cpp` 3. New passes - `include/torch-mlir/Dialect/Torch/Transforms/Passes.td` - `include/torch-mlir/Dialect/Torch/Transforms/Passes.h` - `lib/Dialect/Torch/Transforms/Passes.cpp` - `lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp` - `lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp` - `lib/Dialect/Torch/Transforms/ReifyCalculationsUtils.cpp` - `lib/Dialect/Torch/Transforms/ReifyCalculationsUtils.h` - `lib/Dialect/Torch/Transforms/SimplifyCalculationsUtils.cpp` - `lib/Dialect/Torch/Transforms/SimplifyCalculationsUtils.h` The `*Utils.*` files include logic that is shared by dtype and shape passes. 4. Tests - `test/Dialect/Torch/ops.mlir` - `test/Dialect/Torch/reify-dtype-calculations.mlir` - `test/Dialect/Torch/simplify-dtype-calculations.mlir` 5. Introduce `torch_mlir_promote_dtypes` - `python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp` - `python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py` 6. Simple refactoring/generalizing - `include/torch-mlir/Dialect/Torch/Utils/Utils.h` - `lib/Dialect/Torch/Utils/Utils.cpp` - `lib/Dialect/Torch/Transforms/DropCalculations.cpp` - `lib/Dialect/Torch/Transforms/RefineTypes.cpp` - `lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp` - `lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp` - `python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py` - `python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py` - `python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/calculations_lib_gen.py` - `lib/Dialect/Torch/Transforms/CalculationsLibrary.cpp` 7. The rest of the files include minor changes - Replace `shape` with `calculations` - Replace `m_TorchConstantIntList` with `m_TorchListOfConstantInts` (needed to avoid ambiguity with new pattern `m_TorchListOfOptionalConstantInts`)
- Loading branch information
Showing
56 changed files
with
3,657 additions
and
2,338 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Adding a Shape and Dtype Function | ||
|
||
## Overview | ||
|
||
As part of adding support for a Torch operator in Torch-MLIR, it is | ||
usually necessary to define a shape and dtype function so that the | ||
compiler can infer the shapes and dtypes of result tensors for the | ||
operator. We use the [calculations library](calculations_lib.md) for | ||
this process. | ||
|
||
## Step-by-step guide | ||
|
||
We will use the example of adding support for the `torch.aten.tanh` op. | ||
|
||
1. First, you need to find the shape and dtype function signatures for | ||
the operator you are implementing a functions for. This can be | ||
found in | ||
`include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt` | ||
generated by the `build_tools/update_torch_ods.sh` script. That | ||
file is the "rosetta stone" that allows translating between | ||
e.g. `torch.aten.tanh`, `AtenTanhOp`, and the shape and dtype | ||
function signatures are: | ||
|
||
- `def aten〇tanh〡shape(self: List[int]) -> List[int]:` | ||
- `def aten〇tan_〡dtype(self_rank: int, self_dtype: int) -> int:` | ||
|
||
Note the use of `〇` as a separator since `.` or `::` aren't legal | ||
in a Python identifier. | ||
|
||
2. Paste the function signature into `calculations_lib_gen.py` in an | ||
appropriate place (ideally near other functions with a similar | ||
functions). Note that `calculations_lib_gen.py` will check that | ||
these signatures are verbatim identical with the ones given in | ||
`JITOperatorRegistryDump.txt` -- this ensures that the functions | ||
don't get outdated if Torch changes an operator signature. | ||
|
||
3. Fill in the body of the function. Ideally this will just be a call | ||
into a helper function from | ||
[`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1). | ||
But in general, you will need to write the function and test it | ||
(see the comments about "Shape, dtype, and decomposition function | ||
testing infrastructure" in `testing_framework.py`). New shape | ||
functions should be added upstream following the example of [this PR](https://github.com/pytorch/pytorch/pull/76889), | ||
though it can be useful to iterate locally in `calculations_lib_gen.py` | ||
first. | ||
|
||
4. Re-run the `build_tools/update_calculations_lib.sh` script to | ||
update the library. After this step happens, ideally everything | ||
"just works" and the functions are now correctly inferred for the | ||
operator. | ||
|
||
## When things go wrong | ||
|
||
It is possible that the refinement pipeline (see [Shape and Dtype Refinement Pipeline Architecture](calculations_lib.md#shape-and-dtype-refinement-pipeline-architecture)) | ||
is not able to infer the shape/dtype of a tensor with a given | ||
calculation function. This usually means that there is something about | ||
the function which the optimizations in | ||
`torch-simplify-shape/dtype-functions` | ||
(`lib/Dialect/Torch/Transforms/Simplify{Shape/Dtype}Calculations.cpp`) | ||
cannot handle. | ||
|
||
To debug this, the overall goal is to pinpoint the IR construct that | ||
is not being simplified. This is usually accomplished by a combination | ||
of looking at the Python code for the function and the IR dumps. The | ||
best IR dump to look at varies, but frequently the IR dump right | ||
before `DropCalculations` is the most useful, because it has already | ||
been simplified as much as possible, making it is easy to see what is | ||
blocking further simplification. Examples of issues you might see: | ||
|
||
- You might find that there is a loop with a non-constant trip count, | ||
but based on your understanding of the function, you would expect it | ||
to be simplified to a constant trip count -- you can then look at | ||
the trip count calculation and see if there is a missing fold or | ||
canonicalization. | ||
|
||
- You might find that there is a list operation that is not currently understood | ||
by the optimizations. You can then teach the optimizations about that | ||
operation. | ||
|
||
- You might find that there is an `Optional` value that you would | ||
expect to be resolved to either a concrete value or `None`. You can | ||
then look at the calculation that produces the optional value and | ||
see what folds or canonicalizations are missing. | ||
|
||
See [this video](https://www.youtube.com/watch?v=E5epCJOtrf8) for general | ||
guidance on debugging Torch-MLIR. | ||
|
||
As a last resort, you can rewrite the function using constructs that | ||
`torch-simplify-shape/dtype-functions` can handle (look at other | ||
functions for examples, sometimes it requires writing things a little | ||
awkwardly). |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Torch-MLIR Calculations Library Infrastructure | ||
|
||
## Overview | ||
|
||
The Torch-MLIR project has an infrastructure for maintaining a library of | ||
calculation functions for different Torch operators, which supply extra | ||
information such as result dtypes and shapes as well as decompositions. These | ||
functions are fully executable specifications of the shape/dtype/decomposition | ||
functions for each operator and can be authored and tested from Python for | ||
convenience. These are then brought into the compiler and can be manipulated / | ||
transformed for various purposes. Additionally, in the case of shape functions, | ||
this effort is synergistic with upstream PyTorch efforts to maintain a library | ||
of shape functions. | ||
|
||
The two main use cases are: | ||
|
||
- Refinement / inference. The `torch-shape/dtype-refinement-pipeline` pass | ||
pipeline orchestrates a series of passes that use the available information in | ||
the program to further refine the types in the program. | ||
|
||
- Error guard insertion for backends (Not Yet Implemented). The executable | ||
functions can include error guards / assertions that abort the program in case | ||
of invalid input (such as a matmul with a mismatching contracting dimension). | ||
|
||
## Architecture | ||
|
||
Functions are defined as TorchScript-able Python functions in | ||
`python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/calculations_lib_gen.py`. | ||
The signatures of the functions are systematically derived from Torch JIT | ||
operator registry. Most shape functions are expected to reuse the upstream | ||
helper functions | ||
[`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1), | ||
and any new shape functions should be added there. | ||
|
||
The `build_tools/update_calculations_lib.sh` script invokes | ||
`calculations_lib_gen.py` to generate an MLIR module containing the functions, | ||
which is currently embedded as a string literal in | ||
`lib/Dialect/Torch/Transforms/CalculationsLibrary.cpp`. | ||
|
||
The function `StringRef mlir::torch::Torch::getCalculationsLibrary()` is | ||
available for use inside the compiler any time that the library is needed. | ||
|
||
## Shape and Dtype Refinement Pipeline Architecture | ||
|
||
One of the main services that Torch-MLIR provides for backends is to normalize | ||
all Torch frontends into a common form which includes tensor shapes and dtypes | ||
that are as precise as possible. This alleviates the need for backends to solve | ||
this problem themselves. This process of shape and dtype refinement is | ||
accomplished in Torch-MLIR through a pipeline of passes which uses the | ||
calculations library combined with abstract interpretation of the calculation | ||
functions to calculate shapes and dtypes that are as precise as possible. | ||
|
||
The pipeline works as follows: | ||
|
||
1. Calculations are reified. The `torch-reify-shape/dtype-calculations` reifies | ||
(i.e., materializes into the IR) the functions for each op with a function in | ||
the calculation library. To do this, it wraps those ops in a | ||
`torch.shape/dtype.calculate` op, which has two regions: 1) a body with the | ||
op itself, and 2) the shape or dtype calculation, which calculates the shapes | ||
or dtypes of the tensors yielded by the body. | ||
|
||
2. Simplifying the functions and propagating the shapes and dtypes. After the | ||
functions are reified, we then attempt to "optimize hard enough" until the | ||
shapes and dtypes yielded by the calculation regions become obvious in the IR. | ||
Those results are propagated through the IR, which usually reveals more | ||
opportunities for simplification. | ||
|
||
a. After reification, the functions are just a loose collection of | ||
functions, which are difficult to analyze. The first step is to inline them. | ||
|
||
b. After inlining, the `torch-simplify-shape/dtype-calculations` pass is used | ||
to simplify the calculations. This pass brings in a number of targeted | ||
canonicalization patterns and folds, along with a few specific patterns such | ||
as unrolling fixed-trip-count loops and abstractly interpreting list | ||
operations (an example is turning a series of "append" operations into a list | ||
literal). This pass also looks at the values yielded by the calculation | ||
regions, and if the resulting shape or dtype can be deduced by looking at the | ||
IR (for example, the shape is the list literal `[1, 2, 3]`), it will refine | ||
the types of the `torch.shape/dtype.calculate` op. This usually yields more | ||
opportunities for simplification. This process runs to a fixed-point. | ||
|
||
3. Dropping the calculations. Once all the types in the program have been | ||
refined as much as possible, the ops that were originally wrapped in | ||
`torch.shape/dtype.calculate` are unwrapped by the `torch-drop-calculations` | ||
pass which drops the reified calculations, leaving behind the shape and dtype | ||
refined program. | ||
|
||
Inferring precise shapes and dtypes often is needed for correctness by | ||
backends. That said, requiring "optimizing hard enough" for correctness is | ||
usually considered quite brittle in a compiler flow. In this case, the saving | ||
grace is that we are only optimizing the functions, which are authored by | ||
compiler developers (not users), and thus there is some give-and-take in terms | ||
of understanding the optimizable constructs while authoring the functions, or | ||
improving the optimizations to enable easier authoring. Some brittleness is | ||
likely to escape to users, unfortunately, since there will always be situations | ||
where, for example, a statically shaped program allows the shape functions to be | ||
simplified to a greater extent than in a dynamically shaped program (for | ||
example, if the shape function checks "is this dimension of size 1"). We hope | ||
that this is minimal. | ||
|
||
## Adding to the calculations library | ||
|
||
See [Adding a Shape and Dtype Function](adding_a_shape_and_dtype_function.md) | ||
for details on how to add a shape and dtype function for an operator. | ||
|
||
## Rationale | ||
|
||
### Use of full operator signatures | ||
|
||
The use of the full operator signature such as | ||
`def aten〇add〇Tensor(self: List[int], other: List[int], alpha: float = 1) -> List[int]:` | ||
for defining calculation functions is somewhat verbose and repetitive, especially when | ||
there are multiple identical functions. Upstream uses a map with key-value | ||
pairs like `"aten.add.Tensor": upstream_shape_functions.broadcast`, which is | ||
more compact and less repetitive in some ways (upstream also allows trailing | ||
arguments beyond those accepted by the shape function to be ignored, allowing | ||
further deduplication). The decision to do it the more verbose way in Torch-MLIR | ||
was based on the following goals: | ||
|
||
- To make the system very easy to debug and test. | ||
|
||
- To make the system maximally consistent between functions that are | ||
implemented with the upstream shape helpers and the ones that are manually | ||
written, which are still a fairly large and non-trivial set. | ||
|
||
- To make it as mechanical as possible to add a new function. |
Oops, something went wrong.