[custom op] Generalize shape library logic to work with dtypes
This commit generalizes the shape library logic, so that dtype rules
for ops can also be expressed using the same mechanism. In other
words, each op can now have a shape function and a dtype function
specified in Python that is imported during lowering to calculate the
shapes and dtypes throughout a program. For more information about how
to specify a dtype function, see the updated
`docs/adding_a_shape_and_dtype_function.md`.

For those not familiar with how the shape library works, the file
`docs/calculations_lib.md` provides an overview.

To make reviewing a bit easier, I suggest the following review
order:

1. Get familiar with the overall architecture by reading
`docs/calculations_lib.md`
2. New op declarations
   - `include/torch-mlir/Dialect/Torch/IR/TorchOps.td`
   - `lib/Dialect/Torch/IR/TorchOps.cpp`
3. New passes
   - `include/torch-mlir/Dialect/Torch/Transforms/Passes.td`
   - `include/torch-mlir/Dialect/Torch/Transforms/Passes.h`
   - `lib/Dialect/Torch/Transforms/Passes.cpp`
   - `lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp`
   - `lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp`
   - `lib/Dialect/Torch/Transforms/ReifyCalculationsUtils.cpp`
   - `lib/Dialect/Torch/Transforms/ReifyCalculationsUtils.h`
   - `lib/Dialect/Torch/Transforms/SimplifyCalculationsUtils.cpp`
   - `lib/Dialect/Torch/Transforms/SimplifyCalculationsUtils.h`

   The `*Utils.*` files include logic that is shared by dtype and
   shape passes.
4. Tests
   - `test/Dialect/Torch/ops.mlir`
   - `test/Dialect/Torch/reify-dtype-calculations.mlir`
   - `test/Dialect/Torch/simplify-dtype-calculations.mlir`
5. Introduce `torch_mlir_promote_dtypes`
   - `python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp`
   - `python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py`
6. Simple refactoring/generalizing
   - `include/torch-mlir/Dialect/Torch/Utils/Utils.h`
   - `lib/Dialect/Torch/Utils/Utils.cpp`
   - `lib/Dialect/Torch/Transforms/DropCalculations.cpp`
   - `lib/Dialect/Torch/Transforms/RefineTypes.cpp`
   - `lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp`
   - `lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp`
   - `python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py`
   - `python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py`
   - `python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/calculations_lib_gen.py`
   - `lib/Dialect/Torch/Transforms/CalculationsLibrary.cpp`
7. The rest of the files include minor changes
   - Replace `shape` with `calculations`
   - Replace `m_TorchConstantIntList` with
   `m_TorchListOfConstantInts` (needed to avoid ambiguity with new
   pattern `m_TorchListOfOptionalConstantInts`)
ramiro050 committed Nov 16, 2022
1 parent fc4c8d4 commit cd6fd6f
Showing 56 changed files with 3,657 additions and 2,338 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/RollPyTorch.yml
@@ -64,14 +64,14 @@ jobs:
echo "PT_RELEASE=${PT_RELEASE}" >> ${GITHUB_ENV}
echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV}
- - name: Build and test (in-tree), also update ODS and shape library
+ - name: Build and test (in-tree), also update ODS and calculations library
if: env.PT_HASH_CHANGED != '0'
run: |
cd ${GITHUB_WORKSPACE}
TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \
TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \
TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \
- TM_UPDATE_ODS_AND_SHAPE_LIB="ON" \
+ TM_UPDATE_ODS_AND_CALCULATIONS_LIB="ON" \
./build_tools/python_deploy/build_linux_packages.sh
- name: Push changes to main branch
@@ -82,7 +82,7 @@ jobs:
git config user.name "Roll PyTorch Action"
git fetch --recurse-submodules=no
git checkout main
- git add pytorch-hash.txt pytorch-requirements.txt lib/Dialect/Torch/Transforms/ShapeLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+ git add pytorch-hash.txt pytorch-requirements.txt lib/Dialect/Torch/Transforms/CalculationsLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
git diff --cached --exit-code || (git commit -m "update PyTorch version to ${{ env.PT_RELEASE }}" && git push --set-upstream origin main)
- name: Update PyTorch Build Cache (if running on main branch)
14 changes: 7 additions & 7 deletions build_tools/python_deploy/build_linux_packages.sh
@@ -53,8 +53,8 @@ TM_PACKAGES="${TM_PACKAGES:-torch-mlir}"
TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
# Skip running tests if you want quick iteration
TM_SKIP_TESTS="${TM_SKIP_TESTS:-OFF}"
- # Update ODS and shape library files
- TM_UPDATE_ODS_AND_SHAPE_LIB="${TM_UPDATE_ODS_AND_SHAPE_LIB:-OFF}"
+ # Update ODS and calculations library files
+ TM_UPDATE_ODS_AND_CALCULATIONS_LIB="${TM_UPDATE_ODS_AND_CALCULATIONS_LIB:-OFF}"

PKG_VER_FILE="${repo_root}"/torch_mlir_package_version ; [ -f "$PKG_VER_FILE" ] && . "$PKG_VER_FILE"
TORCH_MLIR_PYTHON_PACKAGE_VERSION="${TORCH_MLIR_PYTHON_PACKAGE_VERSION:-0.0.1}"
@@ -119,7 +119,7 @@ function run_on_host() {
-e "TM_PYTHON_VERSIONS=${TM_PYTHON_VERSIONS}" \
-e "TM_PACKAGES=${package}" \
-e "TM_SKIP_TESTS=${TM_SKIP_TESTS}" \
- -e "TM_UPDATE_ODS_AND_SHAPE_LIB=${TM_UPDATE_ODS_AND_SHAPE_LIB}" \
+ -e "TM_UPDATE_ODS_AND_CALCULATIONS_LIB=${TM_UPDATE_ODS_AND_CALCULATIONS_LIB}" \
-e "TM_USE_PYTORCH_BINARY=${TM_USE_PYTORCH_BINARY}" \
-e "TORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO}" \
-e "TORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH}" \
@@ -164,10 +164,10 @@ function run_in_docker() {
in-tree)
setup_venv "$python_version"
build_in_tree "$TM_USE_PYTORCH_BINARY" "$python_version"
- if [ "${TM_UPDATE_ODS_AND_SHAPE_LIB}" == "ON" ]; then
+ if [ "${TM_UPDATE_ODS_AND_CALCULATIONS_LIB}" == "ON" ]; then
pushd /main_checkout/torch-mlir
./build_tools/update_torch_ods.sh
- ./build_tools/update_shape_lib.sh
+ ./build_tools/update_calculations_lib.sh
popd
fi
if [ "${TM_SKIP_TESTS}" == "OFF" ]; then
@@ -253,8 +253,8 @@ function test_in_tree() {
cd /main_checkout/torch-mlir/
export PYTHONPATH="/main_checkout/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir"

- echo ":::: Check that update_shape_lib.sh has been run"
- _check_file_not_changed_by ./build_tools/update_shape_lib.sh lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+ echo ":::: Check that update_calculations_lib.sh has been run"
+ _check_file_not_changed_by ./build_tools/update_calculations_lib.sh lib/Dialect/Torch/Transforms/CalculationsLibrary.cpp

echo ":::: Check that update_torch_ods.sh has been run"
_check_file_not_changed_by ./build_tools/update_torch_ods.sh include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -1,5 +1,5 @@
#!/bin/bash
- # Updates auto-generated shape library files for the `torch` dialect.
+ # Updates auto-generated calculations library files for the `torch` dialect.
#
# Environment variables:
# TORCH_MLIR_EXT_MODULES: comma-separated list of python module names
@@ -41,6 +41,6 @@ if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
fi

PYTHONPATH="${pypath}" python \
- -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.shape_lib_gen \
+ -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.calculations_lib_gen \
--pytorch_op_extensions=${ext_module:-""} \
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"
91 changes: 91 additions & 0 deletions docs/adding_a_shape_and_dtype_function.md
@@ -0,0 +1,91 @@
# Adding a Shape and Dtype Function

## Overview

As part of adding support for a Torch operator in Torch-MLIR, it is
usually necessary to define a shape and dtype function so that the
compiler can infer the shapes and dtypes of result tensors for the
operator. We use the [calculations library](calculations_lib.md) for
this process.

## Step-by-step guide

We will use the example of adding support for the `torch.aten.tanh` op.

1. First, you need to find the shape and dtype function signatures for
the operator you are implementing functions for. They can be
found in
`include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt`,
generated by the `build_tools/update_torch_ods.sh` script. That
file is the "rosetta stone" that allows translating between
e.g. `torch.aten.tanh`, `AtenTanhOp`, and the shape and dtype
function signatures, which for this example are:

- `def aten〇tanh〡shape(self: List[int]) -> List[int]:`
- `def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:`

Note the use of `〇` as a separator since `.` or `::` aren't legal
in a Python identifier.

2. Paste the function signatures into `calculations_lib_gen.py` in an
appropriate place (ideally near other functions with a similar
signature). Note that `calculations_lib_gen.py` will check that
these signatures are verbatim identical with the ones given in
`JITOperatorRegistryDump.txt` -- this ensures that the functions
don't get outdated if Torch changes an operator signature.

3. Fill in the body of the function. Ideally this will just be a call
into a helper function from
[`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1).
But in general, you will need to write the function and test it
(see the comments about "Shape, dtype, and decomposition function
testing infrastructure" in `testing_framework.py`). New shape
functions should be added upstream following the example of [this PR](https://github.com/pytorch/pytorch/pull/76889),
though it can be useful to iterate locally in `calculations_lib_gen.py`
first.

4. Re-run the `build_tools/update_calculations_lib.sh` script to
update the library. After this step happens, ideally everything
"just works" and the functions are now correctly inferred for the
operator.
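
As a rough illustration of what steps 1 to 3 produce for `torch.aten.tanh`, the sketch below fills in both signatures. It is a minimal, self-contained sketch and not the code in `calculations_lib_gen.py`: the helper `unary` is a local stand-in for an upstream shape helper, the integer dtype codes are declared locally as assumptions (they mirror PyTorch's `ScalarType` numbering), and the dtype rule shown (floating-point inputs keep their dtype, everything else becomes float32) is an assumed rule chosen for illustration.

```python
from typing import List

# Assumed dtype codes for illustration (they mirror PyTorch's ScalarType numbering).
INT64 = 4
FLOAT16 = 5
FLOAT32 = 6
FLOAT64 = 7
BFLOAT16 = 15

def unary(self: List[int]) -> List[int]:
    # Local stand-in for an upstream helper: elementwise ops keep the input shape.
    out: List[int] = []
    for dim in self:
        out.append(dim)
    return out

def aten〇tanh〡shape(self: List[int]) -> List[int]:
    return unary(self)

def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:
    # Assumed rule: floating-point inputs keep their dtype, others become float32.
    if self_dtype == FLOAT64 or self_dtype == BFLOAT16 or self_dtype == FLOAT16:
        return self_dtype
    return FLOAT32
```

Because the functions are plain Python, they can be called directly while iterating, before the library is regenerated.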

## When things go wrong

It is possible that the refinement pipeline (see [Shape and Dtype Refinement Pipeline Architecture](calculations_lib.md#shape-and-dtype-refinement-pipeline-architecture))
is not able to infer the shape/dtype of a tensor with a given
calculation function. This usually means that there is something about
the function which the optimizations in
`torch-simplify-shape/dtype-calculations`
(`lib/Dialect/Torch/Transforms/Simplify{Shape/Dtype}Calculations.cpp`)
cannot handle.

To debug this, the overall goal is to pinpoint the IR construct that
is not being simplified. This is usually accomplished by a combination
of looking at the Python code for the function and the IR dumps. The
best IR dump to look at varies, but frequently the IR dump right
before `DropCalculations` is the most useful, because it has already
been simplified as much as possible, making it easy to see what is
blocking further simplification. Examples of issues you might see:

- You might find that there is a loop with a non-constant trip count,
but based on your understanding of the function, you would expect it
to be simplified to a constant trip count -- you can then look at
the trip count calculation and see if there is a missing fold or
canonicalization.

- You might find that there is a list operation that is not currently understood
by the optimizations. You can then teach the optimizations about that
operation.

- You might find that there is an `Optional` value that you would
expect to be resolved to either a concrete value or `None`. You can
then look at the calculation that produces the optional value and
see what folds or canonicalizations are missing.
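
To make the first two bullets concrete, here is a hypothetical shape function (not taken from the library) whose loop trip count is `len(self)`. Assuming the rank of the input is known statically, that trip count is a constant, so the loop can in principle be unrolled and the `append` calls collapsed into a list literal; with an unknown rank, neither can happen. The function name and its exact semantics are assumptions made for illustration.

```python
from typing import List

def hypothetical_sum_dim_shape(self: List[int], dim: int, keepdim: bool) -> List[int]:
    # Normalize a negative dim; the trip count below depends only on the rank.
    if dim < 0:
        dim += len(self)
    out: List[int] = []
    for i in range(len(self)):  # constant trip count once the rank is known
        if i == dim:
            if keepdim:
                out.append(1)
        else:
            out.append(self[i])
    return out
```

When a function like this fails to simplify, the trip count (`len(self)`) and the list operations (`append`) are the first things worth inspecting in the IR dump.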

See [this video](https://www.youtube.com/watch?v=E5epCJOtrf8) for general
guidance on debugging Torch-MLIR.

As a last resort, you can rewrite the function using constructs that
`torch-simplify-shape/dtype-calculations` can handle (look at other
functions for examples, sometimes it requires writing things a little
awkwardly).
75 changes: 0 additions & 75 deletions docs/adding_a_shape_function.md

This file was deleted.

126 changes: 126 additions & 0 deletions docs/calculations_lib.md
@@ -0,0 +1,126 @@
# Torch-MLIR Calculations Library Infrastructure

## Overview

The Torch-MLIR project has an infrastructure for maintaining a library of
calculation functions for different Torch operators, which supply extra
information such as result dtypes and shapes as well as decompositions. These
functions are fully executable specifications of the shape/dtype/decomposition
functions for each operator and can be authored and tested from Python for
convenience. These are then brought into the compiler and can be manipulated /
transformed for various purposes. Additionally, in the case of shape functions,
this effort is synergistic with upstream PyTorch efforts to maintain a library
of shape functions.

The two main use cases are:

- Refinement / inference. The `torch-shape/dtype-refinement-pipeline` pass
pipeline orchestrates a series of passes that use the available information in
the program to further refine the types in the program.

- Error guard insertion for backends (Not Yet Implemented). The executable
functions can include error guards / assertions that abort the program in case
of invalid input (such as a matmul with a mismatching contracting dimension).

## Architecture

Functions are defined as TorchScript-able Python functions in
`python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/calculations_lib_gen.py`.
The signatures of the functions are systematically derived from the Torch JIT
operator registry. Most shape functions are expected to reuse the upstream
helper functions
[`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1),
and any new shape functions should be added there.

The `build_tools/update_calculations_lib.sh` script invokes
`calculations_lib_gen.py` to generate an MLIR module containing the functions,
which is currently embedded as a string literal in
`lib/Dialect/Torch/Transforms/CalculationsLibrary.cpp`.

The function `StringRef mlir::torch::Torch::getCalculationsLibrary()` is
available for use inside the compiler any time that the library is needed.
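
The embedding step can be pictured with the sketch below. It is not the generator's actual code: `emit_library_cpp` is a hypothetical helper, and the raw-string delimiter and exact C++ layout are assumptions. It only shows the idea of wrapping the generated MLIR text in a string literal returned by `getCalculationsLibrary()`.

```python
def emit_library_cpp(mlir_asm: str) -> str:
    """Hypothetical helper: wrap MLIR text in a C++ raw string literal."""
    delimiter = "mlir"  # assumed raw-string delimiter
    # The MLIR text must not contain the sequence that would end the literal early.
    assert f"){delimiter}\"" not in mlir_asm
    return (
        "StringRef mlir::torch::Torch::getCalculationsLibrary() {\n"
        f'  return R"{delimiter}(\n{mlir_asm}\n){delimiter}";\n'
        "}\n"
    )

cpp_source = emit_library_cpp("module {\n}")
```

Keeping the library as text inside the binary means the compiler can parse it on demand without locating a separate file at run time.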

## Shape and Dtype Refinement Pipeline Architecture

One of the main services that Torch-MLIR provides for backends is to normalize
all Torch frontends into a common form which includes tensor shapes and dtypes
that are as precise as possible. This alleviates the need for backends to solve
this problem themselves. This process of shape and dtype refinement is
accomplished in Torch-MLIR through a pipeline of passes which uses the
calculations library combined with abstract interpretation of the calculation
functions to calculate shapes and dtypes that are as precise as possible.

The pipeline works as follows:

1. Calculations are reified. The `torch-reify-shape/dtype-calculations` pass reifies
(i.e., materializes into the IR) the functions for each op with a function in
the calculation library. To do this, it wraps those ops in a
`torch.shape/dtype.calculate` op, which has two regions: 1) a body with the
op itself, and 2) the shape or dtype calculation, which calculates the shapes
or dtypes of the tensors yielded by the body.

2. Simplifying the functions and propagating the shapes and dtypes. After the
functions are reified, we then attempt to "optimize hard enough" until the
shapes and dtypes yielded by the calculation regions become obvious in the IR.
Those results are propagated through the IR, which usually reveals more
opportunities for simplification.

a. After reification, the functions are just a loose collection of
functions, which are difficult to analyze. The first step is to inline them.

b. After inlining, the `torch-simplify-shape/dtype-calculations` pass is used
to simplify the calculations. This pass brings in a number of targeted
canonicalization patterns and folds, along with a few specific patterns such
as unrolling fixed-trip-count loops and abstractly interpreting list
operations (an example is turning a series of "append" operations into a list
literal). This pass also looks at the values yielded by the calculation
regions, and if the resulting shape or dtype can be deduced by looking at the
IR (for example, the shape is the list literal `[1, 2, 3]`), it will refine
the types of the `torch.shape/dtype.calculate` op. This usually yields more
opportunities for simplification. This process runs to a fixed-point.

3. Dropping the calculations. Once all the types in the program have been
refined as much as possible, the ops that were originally wrapped in
`torch.shape/dtype.calculate` are unwrapped by the `torch-drop-calculations`
pass which drops the reified calculations, leaving behind the shape and dtype
refined program.
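
The "propagate and repeat until nothing changes" idea behind step 2 can be modeled with a small toy. This is only an analogy under stated assumptions: the real pipeline inlines and simplifies MLIR rather than calling Python, and `refine`, `tanh_shape`, `mm_shape`, and the dictionary-based program representation are all hypothetical names invented here.

```python
from typing import Callable, Dict, List, Optional, Tuple

Shape = Optional[List[int]]  # None models a shape that is not yet known

def tanh_shape(self: List[int]) -> List[int]:
    return list(self)

def mm_shape(self: List[int], mat2: List[int]) -> List[int]:
    assert self[1] == mat2[0], "mismatching contracting dimension"
    return [self[0], mat2[1]]

def refine(program: Dict[str, Tuple[Callable[..., List[int]], List[str]]],
           shapes: Dict[str, Shape]) -> Dict[str, Shape]:
    changed = True
    while changed:  # run to a fixed point
        changed = False
        for result, (shape_fn, operands) in program.items():
            args = [shapes.get(name) for name in operands]
            if any(arg is None for arg in args):
                continue  # not enough information yet
            new_shape = shape_fn(*args)
            if shapes.get(result) != new_shape:
                shapes[result] = new_shape
                changed = True
    return shapes

# Ops are listed out of order on purpose so that more than one sweep is needed.
program = {
    "z": (mm_shape, ["y", "w"]),
    "y": (tanh_shape, ["x"]),
}
shapes = refine(program, {"x": [2, 3], "w": [3, 4]})
```

Here the shape of `z` only becomes computable after `y` has been refined, which mirrors how each refinement in the real pipeline can expose further simplification opportunities.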

Inferring precise shapes and dtypes is often needed for correctness by
backends. That said, requiring "optimizing hard enough" for correctness is
usually considered quite brittle in a compiler flow. In this case, the saving
grace is that we are only optimizing the functions, which are authored by
compiler developers (not users), and thus there is some give-and-take in terms
of understanding the optimizable constructs while authoring the functions, or
improving the optimizations to enable easier authoring. Some brittleness is
likely to escape to users, unfortunately, since there will always be situations
where, for example, a statically shaped program allows the shape functions to be
simplified to a greater extent than in a dynamically shaped program (for
example, if the shape function checks "is this dimension of size 1"). We hope
that this is minimal.
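
The "is this dimension of size 1" case can be illustrated with a deliberately conservative toy broadcast, where `None` stands for a dimension whose size is unknown statically. The names `Dim`, `broadcast_dim`, and `broadcast` are hypothetical, and the rule of giving up whenever the check cannot be decided is an assumption of this sketch rather than a description of the actual simplification passes.

```python
from typing import List, Optional

Dim = Optional[int]  # None models a dimension whose size is unknown statically

def broadcast_dim(a: Dim, b: Dim) -> Dim:
    if a == 1:
        return b
    if b == 1:
        return a
    if a is None or b is None:
        return None  # the "is this dimension of size 1" check cannot be decided
    assert a == b, "shapes are not broadcastable"
    return a

def broadcast(lhs: List[Dim], rhs: List[Dim]) -> List[Dim]:
    rank = max(len(lhs), len(rhs))
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    lhs = [1] * (rank - len(lhs)) + list(lhs)
    rhs = [1] * (rank - len(rhs)) + list(rhs)
    return [broadcast_dim(a, b) for a, b in zip(lhs, rhs)]
```

With fully static inputs every branch resolves and the result is a concrete list; with a dynamic dimension the same function leaves part of the result unknown.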

## Adding to the calculations library

See [Adding a Shape and Dtype Function](adding_a_shape_and_dtype_function.md)
for details on how to add a shape and dtype function for an operator.

## Rationale

### Use of full operator signatures

The use of the full operator signature such as
`def aten〇add〇Tensor〡shape(self: List[int], other: List[int], alpha: float = 1) -> List[int]:`
for defining calculation functions is somewhat verbose and repetitive, especially when
there are multiple identical functions. Upstream uses a map with key-value
pairs like `"aten.add.Tensor": upstream_shape_functions.broadcast`, which is
more compact and less repetitive in some ways (upstream also allows trailing
arguments beyond those accepted by the shape function to be ignored, allowing
further deduplication). The decision to do it the more verbose way in Torch-MLIR
was based on the following goals:

- To make the system very easy to debug and test.

- To make the system maximally consistent between functions that are
implemented with the upstream shape helpers and the ones that are manually
written, which are still a fairly large and non-trivial set.

- To make it as mechanical as possible to add a new function.
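
The trade-off can be seen side by side in the sketch below. It is illustrative only: `broadcast` is a simplified local stand-in for an upstream helper (it handles equal ranks only), `shape_fn_map` is a hypothetical name for the map style described above, and the full-signature functions follow the `〡shape` suffix used in the step-by-step guide.

```python
from typing import Callable, Dict, List

def broadcast(a: List[int], b: List[int]) -> List[int]:
    # Simplified stand-in for an upstream broadcasting helper (equal ranks only).
    assert len(a) == len(b)
    out: List[int] = []
    for x, y in zip(a, b):
        assert x == y or x == 1 or y == 1
        out.append(y if x == 1 else x)
    return out

# Map style: compact, one entry per operator, signature not spelled out.
shape_fn_map: Dict[str, Callable[..., List[int]]] = {
    "aten.add.Tensor": broadcast,
    "aten.mul.Tensor": broadcast,
}

# Full-signature style: one function per operator, repeating the full signature.
def aten〇add〇Tensor〡shape(self: List[int], other: List[int], alpha: float = 1) -> List[int]:
    return broadcast(self, other)

def aten〇mul〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
    return broadcast(self, other)
```

The second style repeats more text, but each operator has a named function that can be called, tested, and stepped through on its own.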