Skip to content

Commit

Permalink
Add extra_library kwarg to torch_mlir.compile (llvm#1986)
Browse files Browse the repository at this point in the history
This commit adds the ability to specify extra abstract interpretation
functions in `torch_mlir.compile` to use during type refinement. This
allows users to easily add custom ops without having to interact with
MLIR or C++ directly.
  • Loading branch information
ramiro050 authored and gpetters94 committed Jul 7, 2023
1 parent 377c223 commit 9e12273
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 27 deletions.
29 changes: 22 additions & 7 deletions python/torch_mlir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from typing import Optional, Sequence, Union, List, Dict, Tuple
from typing import Optional, Sequence, Union, List, Dict, Tuple, Callable, Iterable
from enum import Enum

import sys
from io import StringIO
import tempfile

from torch._functorch.compile_utils import strip_overloads
import torch

from torch_mlir.passmanager import PassManager
from .compiler_utils import run_pipeline_with_repro_report
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library


class OutputType(Enum):
Expand Down Expand Up @@ -252,7 +254,7 @@ def compile(model: torch.nn.Module,
use_tracing: bool = False,
ignore_traced_shapes=False,
backend_legal_ops: Optional[Sequence[str]] = None,
_completely_unsupported_in_progress_extra_library: Optional[str] = None,
extra_library: Iterable[Callable] = [],
verbose: bool = False):
"""Convert a PyTorch model to MLIR.
Expand All @@ -278,12 +280,28 @@ def compile(model: torch.nn.Module,
backend_legal_ops: A list of ops that should be considered legal for
the backend. An op that is considered legal will not be decomposed.
This option is only valid with the `"torch"` output type.
extra_library: List of abstract interpretation functions to splice
into the abstract interpretation library. See
`docs/adding_abstract_interpretation_functions.md` for more info
on the format the functions should have.
verbose: If true, print extra information about the conversion.
Returns:
An MLIR module that contains the converted model in the specified
output type.
"""
extra_library_file_name = ""
if len(extra_library) != 0:
extra_library_dict = {}
for library_func in extra_library:
extra_library_dict[library_func.__name__] = library_func
mlir_library = generate_library(extra_library_dict)

extra_library_file_name = \
tempfile.gettempdir() + "/custom_op_extra_library.mlir"
with open(extra_library_file_name, "w") as f:
f.write(mlir_library)

output_type = OutputType.get(output_type)
example_args = ExampleArgs.get(example_args)
if ignore_traced_shapes and not use_tracing:
Expand Down Expand Up @@ -368,11 +386,8 @@ def compile(model: torch.nn.Module,
if output_type == OutputType.RAW:
return mb.module

option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + (
(" extra-library=" + _completely_unsupported_in_progress_extra_library)
if (_completely_unsupported_in_progress_extra_library is not None)
else ""
) + "}"
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \
" extra-library=" + extra_library_file_name + "}"
run_pipeline_with_repro_report(
mb.module,
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import inspect
import re
from typing import List, Optional, Union
from typing import List, Optional, Union, Any, Dict

import torch

Expand Down Expand Up @@ -138,25 +138,30 @@ def _verify_signature_matches_registry(f, registry: Registry):
atoms = function_name.split("〇")
if len(atoms) == 2:
atoms += [""]
operator = registry.get_by_triple(tuple(atoms))
try:
operator = registry.get_by_triple(tuple(atoms))
except KeyError as e:
raise ValueError(f"Unable to find op {'.'.join(atoms)} in registry")
if function_kind == "shape":
expected_signature = operator.get_shape_function_signature()
elif function_kind == "dtype":
expected_signature = operator.get_dtype_function_signature()
elif function_kind == "decomposition":
expected_signature = operator.get_decomposition_function_signature()
elif function_kind == "has_value_semantics":
expected_signature = operator.get_has_value_semantics_function_signature()
else:
raise ValueError(f"Invalid Op signature function kind: '{function_kind}'")
if signature != expected_signature:
raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}")

def generate_library(globals_) -> str:
"""Convert all op functions in `globals()` into MLIR."""
def generate_library(functions: Dict[str, Any]) -> str:
"""Convert all op functions in `functions` into MLIR."""
mb = ModuleBuilder()
# We use the registry to ensure that the shape functions are consistent
# with the ops.
registry = Registry.load()
for k, v in globals_.items():
for k, v in functions.items():
if "〇" not in k:
continue
if not hasattr(v, "_not_present_in_registry"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,23 @@ def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
return self._get_function_signature(
"decomposition", parameter_decl_builder, ret_decl_builder)

def get_has_value_semantics_function_signature(self):
"""Gets the Python function signature for this op's has_value_semantics function.
While this is technically debug-only output, it is useful to copy-paste
it from the debug dump into the library definitions, as many
ops have extra default arguments and stuff that are tedious to write out
right.
"""
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
return ""

def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
return "None"

return self._get_function_signature(
"has_value_semantics", parameter_decl_builder, ret_decl_builder)

def __repr__(self):
f = io.StringIO()
emitter = TextEmitter(f)
Expand Down
29 changes: 14 additions & 15 deletions test/python/custom_op_shape_dtype_fn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import tempfile
from typing import List, Tuple

import torch
import torch.utils.cpp_extension
Expand All @@ -18,6 +19,18 @@ def identity(x: torch.Tensor):
goofy_lib.define("identity(Tensor t) -> Tensor")
goofy_lib.impl("identity", identity)

def goofy〇identity〡shape(t: List[int]) -> List[int]:
return t

def goofy〇identity〡dtype(t_rank_dtype: Tuple[int, int]) -> int:
t_rank, t_dtype = t_rank_dtype
return t_dtype

def goofy〇identity〡has_value_semantics() -> None:
return

extra_library = [
goofy〇identity〡shape, goofy〇identity〡dtype, goofy〇identity〡has_value_semantics]

class CustomOpExampleModule(torch.nn.Module):
def __init__(self):
Expand All @@ -38,26 +51,12 @@ def forward(self, a):
mod = CustomOpExampleModule()
mod.eval()

abstract_interp_src = """\
func.func @__torch_mlir_shape_fn.goofy.identity(%arg0: !torch.list<int>) -> !torch.list<int> {
return %arg0 : !torch.list<int>
}
func.func @__torch_mlir_dtype_fn.goofy.identity(%arg0 : !torch.tuple<int, int>) -> !torch.int {
%0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int
return %0#1 : !torch.int
}
func.func @__torch_mlir_has_value_semantics_fn.goofy.identity() { return }
"""

with open("/tmp/custom_op_shape_dtype_fn.mlir", "w") as tmp:
tmp.write(abstract_interp_src)

module = torch_mlir.compile(
mod,
torch.ones(3, 4),
output_type="torch",
backend_legal_ops=["goofy.identity"],
_completely_unsupported_in_progress_extra_library="/tmp/custom_op_shape_dtype_fn.mlir",
extra_library=extra_library,
)

print(module)
Expand Down

0 comments on commit 9e12273

Please sign in to comment.