Skip to content

Commit

Permalink
[TVMScript] Symbolic shape computing (apache#342)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy authored Jan 8, 2023
1 parent 80808fb commit da11e4b
Show file tree
Hide file tree
Showing 18 changed files with 406 additions and 359 deletions.
12 changes: 0 additions & 12 deletions include/tvm/script/ir_builder/relax/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,6 @@ namespace script {
namespace ir_builder {
namespace relax {

//////////////////////////////// Tensor /////////////////////////////////

/*!
* \brief Create a TensorStructInfo.
* \param shape The shape of the tensor. It's runtime dependent if `shape` is None.
* \param dtype The element data type of the tensor. It's runtime dependent if `dtype` is None.
* \param ndim The number of dimensions of the tensor. It's runtime dependent if `ndim` is -1.
* \return The TensorStructInfo.
*/
TVM_DLL tvm::relax::TensorStructInfo Tensor(Optional<Array<PrimExpr>> shape, DataType dtype,
int ndim = -1);

/////////////////////////////// Function ////////////////////////////////

/*!
Expand Down
28 changes: 16 additions & 12 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,14 @@ def __init__(
struct_info: Optional[StructInfo] = None,
span: Span = None,
) -> None:
if struct_info is not None and not isinstance(struct_info, StructInfo):
raise TypeError(
"struct_info needs to be an instance of StructInfo. "
"If you attempt to pass in shape, "
"use relax.TensorStructInfo(shape, dtype)."
)
if struct_info is not None:
struct_info = tvm.runtime.convert_to_object(struct_info)
if not isinstance(struct_info, StructInfo):
raise TypeError(
"struct_info needs to be an instance of StructInfo. "
"If you attempt to pass in shape, "
"use relax.TensorStructInfo(shape, dtype)."
)
self.__init_handle_by_constructor__(
_ffi_api.Var if isinstance(name_hint, str) else _ffi_api.VarFromId, # type: ignore
name_hint,
Expand Down Expand Up @@ -284,12 +286,14 @@ def __init__(
struct_info: Optional[StructInfo] = None,
span: Span = None,
) -> None:
if struct_info is not None and not isinstance(struct_info, StructInfo):
raise TypeError(
"struct_info needs to be an instance of StructInfo. "
"If you attempt to pass in shape, "
"use relax.TensorStructInfo(shape, dtype)."
)
if struct_info is not None:
struct_info = tvm.runtime.convert_to_object(struct_info)
if not isinstance(struct_info, StructInfo):
raise TypeError(
"struct_info needs to be an instance of StructInfo. "
"If you attempt to pass in shape, "
"use relax.TensorStructInfo(shape, dtype)."
)

self.__init_handle_by_constructor__(
_ffi_api.DataflowVar # type: ignore
Expand Down
102 changes: 31 additions & 71 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
"""IRBuilder for Relax dialect"""

import functools
import inspect
from typing import Dict, List, Optional, Tuple, Union

import tvm
from tvm.ir import Type
from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, TupleType, Var, const
from tvm.relax.struct_info import StructInfo, TensorStructInfo
from tvm import relax
from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const
from tvm.relax.struct_info import StructInfo
from tvm.relax.analysis import get_static_type

############################### Operators ###############################
Expand Down Expand Up @@ -75,50 +77,15 @@
zeros_like,
)
from tvm.relax.utils import convert_to_expr
from tvm.runtime import Object as tvm_Object
from tvm.tir import PrimExpr
from tvm.runtime import Object as tvm_Object, ObjectGeneric

from ..tir import var as _tir_var
from . import _ffi_api, frame

############################## Tensor Type ##############################
##################### Python Native Function Alias ######################

py_print = print
py_tuple = tuple

def tensor(
    shape: Optional[List[Union[PrimExpr, str]]] = None,
    dtype: Optional[str] = None,
    ndim: int = -1,
) -> TensorStructInfo:
    """Helper function for `R.Tensor` in parser.

    Parameters
    ----------
    shape: Optional[List[Union[PrimExpr, str]]]
        The shape of the tensor. It's runtime dependent if `shape` is None.
        A `str` element names a symbolic int64 dimension variable.

    dtype: Optional[str]
        The element data type of the tensor. It's runtime dependent if `dtype` is None.

    ndim: int
        The number of dimensions of the tensor. It's runtime dependent if `ndim` is -1.

    Returns
    -------
    ret: TensorStructInfo
        The result TensorStructInfo
    """
    if shape is not None:
        # Always build a fresh list: the previous in-place replacement mutated
        # the caller's list when a `list` was passed, a surprising side effect.
        # String dimension names become symbolic int64 TIR variables.
        shape = [_tir_var("int64", dim) if isinstance(dim, str) else dim for dim in shape]

    return _ffi_api.Tensor(shape, dtype, ndim)  # pylint: disable=no-member # type: ignore


############################## Other Types ##############################

Object = tvm.relax.ObjectStructInfo() # pylint: disable=invalid-name
Void = TupleType([]) # pylint: disable=invalid-name

############################### Function ################################

Expand Down Expand Up @@ -244,13 +211,16 @@ def call_packed(
args = [convert_to_expr(arg) for arg in args]
if type_args is None:
raise ValueError("R.call_packed is required to have type_args")
if isinstance(type_args, tuple):
if isinstance(type_args, py_tuple):
type_args = list(type_args)
elif not isinstance(type_args, list):
type_args = [type_args]
for i, argument in enumerate(type_args):
if callable(argument):
argument = argument()
# Convert possible StructInfoProxy to StructInfo
if isinstance(argument, ObjectGeneric):
argument = argument.asobject()
if isinstance(argument, StructInfo):
type_args[i] = get_static_type(argument)
elif isinstance(argument, Type):
Expand Down Expand Up @@ -279,11 +249,15 @@ def _tensor_type_wrapper(func):
"""A wrapper to convert StructInfo to relax.DynTensorType"""

def _convert_tensor_type(args):
if isinstance(args, (list, tuple)):
if isinstance(args, (list, py_tuple)):
new_args = [_convert_tensor_type(x) for x in args]
return type(args)(new_args)
if isinstance(args, dict):
return {_convert_tensor_type(k): _convert_tensor_type(v) for k, v in args.items()}
if inspect.isfunction(args):
args = args()
if isinstance(args, ObjectGeneric):
args = args.asobject()
return get_static_type(args) if isinstance(args, StructInfo) else args

@functools.wraps(func)
Expand Down Expand Up @@ -373,47 +347,33 @@ def Else() -> frame.ElseFrame: # pylint: disable=invalid-name
return _ffi_api.Else() # pylint: disable=no-member # type: ignore


######################## Symbolic Shape Rewriter ########################


def RewriteSymbolicShape(
struct_info: StructInfo,
var_table: Dict[str, tvm.tir.Var],
) -> Tuple[StructInfo, List[tvm.tir.Var]]:
"""Helper function to rewrite symbolic shape
This function remaps the symbolic shape by
mapping certain vars to new variables.
struct_info: StructInfo
The input struct info
############################### R.tuple ################################

var_table: Dict[str, tvm.tir.Var]
Dictionary to map name of var to a new var.

def tuple(*fields: List[Expr]) -> Expr:
"""Create a tuple expression.
Parameters
----------
fields : List[Expr]
The fields of the tuple.
Returns
-------
rewritten_info : StructInfo
The rewritten StructInfo
undefined_vars: List[tvm.tir.Var]
List of undefined vars.
res : Expr
The result tuple.
"""
return _ffi_api.RewriteSymbolicShape(
struct_info, var_table
) # pylint: disable=no-member # type: ignore
if len(fields) == 0:
fields = []

return relax.Tuple(fields) # pylint: disable=no-member # type: ignore


############################### Importer ###############################

__all__ = [
"Else",
"If",
"Object",
"RewriteSymbolicShape",
"Then",
"TupleGetItem",
"Void",
"add",
"arg",
"assert_op",
Expand Down Expand Up @@ -466,9 +426,9 @@ def RewriteSymbolicShape(
"subtract",
"sum",
"tanh",
"tensor",
"tril",
"triu",
"tuple",
"variance",
"zeros",
"zeros_like",
Expand Down
28 changes: 0 additions & 28 deletions python/tvm/script/parser/ir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,11 @@
# specific language governing permissions and limitations
# under the License.
"""The base parser for ir module"""
from typing import Optional, Tuple

from tvm.ir import PrimExpr, PrimType, RelayExpr, Type
from ...ir_builder import ir as I
from .._core import Parser, dispatch, doc


def eval_func_type_shape(
    self: Parser, node: doc.FunctionDef
) -> Tuple[Optional[Type], Optional[RelayExpr]]:
    """Evaluate the return type and shape of a function definition.

    Parameters
    ----------
    self : Parser
        The visiting parser.

    node : doc.FunctionDef
        The doc FunctionDef node.
    """
    token = self.get_dispatch_token(node)
    with self.with_dispatch_token(token):
        annotation = self.visit_tvm_annotation(node.returns)
        # Guard-style dispatch on the annotation's evaluated form.
        if annotation is None:
            return None, None
        if isinstance(annotation, tuple) and len(annotation) == 2:
            # relax dialect: already a (type, shape) pair
            return annotation
        if isinstance(annotation, PrimExpr):
            # tir dialect: only a dtype is known, no shape expression
            return PrimType(annotation.dtype), None
        raise TypeError(f"Unsupported annotation type: {annotation}")


@dispatch.register(token="ir", type_name="ClassDef")
def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
"""The class definition visiting method for ir module.
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/script/parser/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
from ...ir_builder.relax import * # pylint: disable=redefined-builtin
from ...ir_builder.relax import ir as _relax
from . import parser as _parser
from .entry import Callable, Shape, Tensor, Tuple, function, match_cast
from .entry import Callable, Object, Shape, Tensor, Tuple, function, match_cast

__all__ = _relax.__all__ + [
"Callable",
"Object",
"Shape",
"Tensor",
"Tuple",
Expand Down
Loading

0 comments on commit da11e4b

Please sign in to comment.