diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h
index e7a5947d97b6..72aab6684ebf 100644
--- a/include/tvm/script/ir_builder/relax/ir.h
+++ b/include/tvm/script/ir_builder/relax/ir.h
@@ -29,18 +29,6 @@ namespace script {
 namespace ir_builder {
 namespace relax {
 
-//////////////////////////////// Tensor /////////////////////////////////
-
-/*!
- * \brief Create a TensorStructInfo.
- * \param shape The shape of the tensor. It's runtime dependent if `shape` is None.
- * \param dtype The element data type of the tensor. It's runtime dependent if `dtype` is None.
- * \param ndim The number of dimensions of the tensor. It's runtime dependent if `ndim` is -1.
- * \return The TensorStructInfo.
- */
-TVM_DLL tvm::relax::TensorStructInfo Tensor(Optional<Array<PrimExpr>> shape, DataType dtype,
-                                            int ndim = -1);
-
 /////////////////////////////// Function ////////////////////////////////
 
 /*!
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 35f388dc6d5b..a1c69635ff2b 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -231,12 +231,14 @@ def __init__(
         struct_info: Optional[StructInfo] = None,
         span: Span = None,
     ) -> None:
-        if struct_info is not None and not isinstance(struct_info, StructInfo):
-            raise TypeError(
-                "struct_info needs to be an instance of StructInfo. "
-                "If you attempt to pass in shape, "
-                "use relax.TensorStructInfo(shape, dtype)."
-            )
+        if struct_info is not None:
+            struct_info = tvm.runtime.convert_to_object(struct_info)
+            if not isinstance(struct_info, StructInfo):
+                raise TypeError(
+                    "struct_info needs to be an instance of StructInfo. "
+                    "If you attempt to pass in shape, "
+                    "use relax.TensorStructInfo(shape, dtype)."
+                )
         self.__init_handle_by_constructor__(
             _ffi_api.Var if isinstance(name_hint, str) else _ffi_api.VarFromId,  # type: ignore
             name_hint,
@@ -284,12 +286,14 @@ def __init__(
         struct_info: Optional[StructInfo] = None,
         span: Span = None,
     ) -> None:
-        if struct_info is not None and not isinstance(struct_info, StructInfo):
-            raise TypeError(
-                "struct_info needs to be an instance of StructInfo. "
-                "If you attempt to pass in shape, "
-                "use relax.TensorStructInfo(shape, dtype)."
-            )
+        if struct_info is not None:
+            struct_info = tvm.runtime.convert_to_object(struct_info)
+            if not isinstance(struct_info, StructInfo):
+                raise TypeError(
+                    "struct_info needs to be an instance of StructInfo. "
+                    "If you attempt to pass in shape, "
+                    "use relax.TensorStructInfo(shape, dtype)."
+                )
         self.__init_handle_by_constructor__(
             _ffi_api.DataflowVar  # type: ignore
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index b5a17b7166fc..976da6c9bc40 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -18,12 +18,14 @@
 """IRBuilder for Relax dialect"""
 import functools
+import inspect
 from typing import Dict, List, Optional, Tuple, Union
 
 import tvm
 from tvm.ir import Type
-from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, TupleType, Var, const
-from tvm.relax.struct_info import StructInfo, TensorStructInfo
+from tvm import relax
+from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const
+from tvm.relax.struct_info import StructInfo
 from tvm.relax.analysis import get_static_type
 
 ############################### Operators ###############################
@@ -75,50 +77,15 @@
     zeros_like,
 )
 from tvm.relax.utils import convert_to_expr
-from tvm.runtime import Object as tvm_Object
-from tvm.tir import PrimExpr
+from tvm.runtime import Object as tvm_Object, ObjectGeneric
 
-from ..tir import var as _tir_var
 from . import _ffi_api, frame
 
-############################## Tensor Type ##############################
+##################### Python Native Function Alias ######################
 
+py_print = print
+py_tuple = tuple
 
-def tensor(
-    shape: Optional[List[Union[PrimExpr, str]]] = None,
-    dtype: Optional[str] = None,
-    ndim: int = -1,
-) -> TensorStructInfo:
-    """Helper function for `R.Tensor` in parser
-    Parameters
-    ----------
-    shape: Optional[List[Union[PrimExpr, str]]]
-        The shape of the tensor. It's runtime dependent if `shape` is None.
-    dtype: Optional[str]
-        The element data type of the tensor. It's runtime dependent if `dtype` is None.
-    ndim: int
-        The number of dimensions of the tensor. It's runtime dependent if `ndim` is -1.
-
-    Returns
-    -------
-    ret: TensorStructInfo
-        The result TensorStructInfo
-    """
-
-    if shape is not None:
-        if not isinstance(shape, list):
-            shape = list(shape)
-
-        for i, s in enumerate(shape):
-            if isinstance(s, str):
-                shape[i] = _tir_var("int64", s)
-
-    return _ffi_api.Tensor(shape, dtype, ndim)  # pylint: disable=no-member # type: ignore
-
-
-############################## Other Types ##############################
-
-Object = tvm.relax.ObjectStructInfo()  # pylint: disable=invalid-name
-Void = TupleType([])  # pylint: disable=invalid-name
 
 
 ############################### Function ################################
@@ -244,13 +211,16 @@ def call_packed(
     args = [convert_to_expr(arg) for arg in args]
     if type_args is None:
         raise ValueError("R.call_packed is required to have type_args")
-    if isinstance(type_args, tuple):
+    if isinstance(type_args, py_tuple):
         type_args = list(type_args)
     elif not isinstance(type_args, list):
         type_args = [type_args]
     for i, argument in enumerate(type_args):
         if callable(argument):
             argument = argument()
+        # Convert possible StructInfoProxy to StructInfo
+        if isinstance(argument, ObjectGeneric):
+            argument = argument.asobject()
         if isinstance(argument, StructInfo):
             type_args[i] = get_static_type(argument)
         elif isinstance(argument, Type):
@@ -279,11 +249,15 @@ def _tensor_type_wrapper(func):
     """A wrapper to convert StructInfo to relax.DynTensorType"""
 
     def _convert_tensor_type(args):
-        if isinstance(args, (list, tuple)):
+        if isinstance(args, (list, py_tuple)):
             new_args = [_convert_tensor_type(x) for x in args]
             return type(args)(new_args)
         if isinstance(args, dict):
             return {_convert_tensor_type(k): _convert_tensor_type(v) for k, v in args.items()}
+        if inspect.isfunction(args):
+            args = args()
+        if isinstance(args, ObjectGeneric):
+            args = args.asobject()
         return get_static_type(args) if isinstance(args, StructInfo) else args
 
     @functools.wraps(func)
@@ -373,35 +347,24 @@ def Else() -> frame.ElseFrame:  # pylint: disable=invalid-name
     return _ffi_api.Else()  # pylint: disable=no-member # type: ignore
 
 
-######################## Symbolic Shape Rewriter ########################
-
-
-def RewriteSymbolicShape(
-    struct_info: StructInfo,
-    var_table: Dict[str, tvm.tir.Var],
-) -> Tuple[StructInfo, List[tvm.tir.Var]]:
-    """Helper function to rewrite symbolic shape
-
-    This function remaps the symbolic shape by
-    mapping certain vars to new variables.
-
-    struct_info: StructInfo
-        The input struct info
+############################### R.tuple ################################
 
-    var_table: Dict[str, tvm.tir.Var]
-        Dictionary to map name of var to a new var.
 
+def tuple(*fields: List[Expr]) -> Expr:
+    """Create a tuple expression.
+    Parameters
+    ----------
+    fields : List[Expr]
+        The fields of the tuple.
 
     Returns
     -------
-    rewritten_info : StructInfo
-        The rewritten StructInfo
-
-    undefined_vars: List[tvm.tir.Var]
-        List of undefined vars.
+    res : Expr
+        The result tuple.
""" - return _ffi_api.RewriteSymbolicShape( - struct_info, var_table - ) # pylint: disable=no-member # type: ignore + if len(fields) == 0: + fields = [] + + return relax.Tuple(fields) # pylint: disable=no-member # type: ignore ############################### Importer ############################### @@ -409,11 +372,8 @@ def RewriteSymbolicShape( __all__ = [ "Else", "If", - "Object", - "RewriteSymbolicShape", "Then", "TupleGetItem", - "Void", "add", "arg", "assert_op", @@ -466,9 +426,9 @@ def RewriteSymbolicShape( "subtract", "sum", "tanh", - "tensor", "tril", "triu", + "tuple", "variance", "zeros", "zeros_like", diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index 1b29e688764a..fb724181d044 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -15,39 +15,11 @@ # specific language governing permissions and limitations # under the License. """The base parser for ir module""" -from typing import Optional, Tuple -from tvm.ir import PrimExpr, PrimType, RelayExpr, Type from ...ir_builder import ir as I from .._core import Parser, dispatch, doc -def eval_func_type_shape( - self: Parser, node: doc.FunctionDef -) -> Tuple[Optional[Type], Optional[RelayExpr]]: - """evaluate function type and shape. - Parameters - ---------- - self : Parser - The visiting parser. - node : doc.FunctionDef - The doc FunctionDef node. - """ - token = self.get_dispatch_token(node) - with self.with_dispatch_token(token): - result = self.visit_tvm_annotation(node.returns) - if result is None: - return None, None - elif isinstance(result, tuple) and len(result) == 2: - # relax dialect - return result - elif isinstance(result, PrimExpr): - # tir dialect - return PrimType(result.dtype), None - else: - raise TypeError(f"Unsupported annotation type: {result}") - - @dispatch.register(token="ir", type_name="ClassDef") def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: """The class definition visiting method for ir module. diff --git a/python/tvm/script/parser/relax/__init__.py b/python/tvm/script/parser/relax/__init__.py index 53bc3b3626ba..648f2f337f91 100644 --- a/python/tvm/script/parser/relax/__init__.py +++ b/python/tvm/script/parser/relax/__init__.py @@ -18,10 +18,11 @@ from ...ir_builder.relax import * # pylint: disable=redefined-builtin from ...ir_builder.relax import ir as _relax from . import parser as _parser -from .entry import Callable, Shape, Tensor, Tuple, function, match_cast +from .entry import Callable, Object, Shape, Tensor, Tuple, function, match_cast __all__ = _relax.__all__ + [ "Callable", + "Object", "Shape", "Tensor", "Tuple", diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index f433ccaf45e2..48fa6f78f7d7 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -16,22 +16,28 @@ # under the License. 
 # pylint: disable=missing-docstring, invalid-name
 import inspect
+from typing import Any
 from typing import Callable as _Callable
-from typing import List, Optional, Tuple
-from typing import TypeVar as _TypeVar
-from typing import Union
-
-from tvm import relax
-from tvm.relax import DynTensorType, Expr, Function, StructInfo
-from tvm.relax import Tuple as RxTuple
-from tvm.relax import Type, Var
+from typing import Dict, List, Optional, Set, TypeVar, Union
+
+from tvm.relax import (
+    Expr,
+    FuncStructInfo,
+    Function,
+    ObjectStructInfo,
+    ShapeStructInfo,
+    StructInfo,
+    TensorStructInfo,
+    TupleStructInfo,
+)
 from tvm.runtime import ObjectGeneric
 from tvm.tir import PrimExpr
 
-from ...ir_builder.relax import tensor
 from .._core import parse, utils
 
-FType = _TypeVar("FType", bound=_Callable)
+FType = TypeVar("FType", bound=_Callable)
+
+############################## R.function ##############################
 
 
 def function(f: FType) -> Union[Function, FType]:
@@ -45,40 +51,89 @@ def function(f: FType) -> Union[Function, FType]:
 setattr(function, "dispatch_token", "relax")
 
 
+############################# Struct Info ##############################
+
+
+class StructInfoProxy(ObjectGeneric):
+    def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> StructInfo:
+        raise NotImplementedError()
+
+    def get_symbolic_vars(self) -> Set[str]:
+        return set()
+
+    def asobject(self):
+        return self.as_struct_info(None)
+
+
 ############################### R.Tensor ###############################
 
 
-class TensorProxy(ObjectGeneric):
-    def __call__(
+def _eval_shape(expr: Union[str, PrimExpr], dict_globals: Optional[Dict[str, Any]]) -> PrimExpr:
+    if isinstance(expr, str):
+        code = compile(expr, "<string>", "eval")
+        return eval(code, dict_globals or {})  # pylint: disable=eval-used
+    else:
+        return expr
+
+
+class TensorProxy(StructInfoProxy):
+    shape: Optional[List[Union[str, PrimExpr]]]
+    dtype: str
+    ndim: int
+
+    def __init__(
         self,
         shape: Optional[List[Union[PrimExpr, str]]] = None,
-        dtype: str = None,
+        dtype: Optional[str] = None,
         ndim: int = -1,
-    ) -> relax.TensorStructInfo:
-        # scalar tensor case
-        if shape is not None and len(shape) == 0:
-            shape = []
-        if isinstance(shape, str) and dtype is None:
-            dtype = shape
-            shape = None
-        return tensor(shape, dtype, ndim)
-
-    def __getitem__(self, keys) -> Var:
-        return self(*keys)  # pylint: disable=no-member # type: ignore
-
-    def asobject(self):
-        """Convert to object when direct call `R.Tensor`
-        e.g. `x = R.invoke_closure(clo, (y,), type_args=R.Tensor)`
-        """
-        return DynTensorType()
+    ) -> None:
+        self.shape = shape
+        self.dtype = dtype
+        self.ndim = ndim
+        super().__init__()
+
+    def get_symbolic_vars(self) -> Set[str]:
+        if self.shape is None:
+            return set()
+        else:
+            return {s for s in self.shape if isinstance(s, str) and s.isidentifier()}
 
+    def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo:
+        if self.shape is None:
+            return TensorStructInfo(None, self.dtype, self.ndim)
+        else:
+            if dict_globals is None and any([isinstance(s, str) for s in self.shape]):
+                raise ValueError(
+                    "String-defined shape expr is only allowed when parsing function parameters "
+                    "and return annotations for TVMScript."
+                )
+            shape = [_eval_shape(s, dict_globals) for s in self.shape]
+            return TensorStructInfo(shape, self.dtype, self.ndim)
+
+
+def Tensor(
+    shape: Optional[List[Union[PrimExpr, str]]] = None,
+    dtype: Optional[str] = None,
+    ndim: int = -1,
+) -> TensorProxy:
+    # scalar tensor case
+    if shape is not None and len(shape) == 0:
+        shape = []
+    if isinstance(shape, str) and dtype is None:
+        dtype = shape
+        shape = None
+
+    if shape is not None and not isinstance(shape, (tuple, list)):
+        raise ValueError(f"shape must be a list or tuple, but got: {shape}")
+    return TensorProxy(shape, dtype, ndim)
 
-Tensor = TensorProxy()  # pylint: disable=invalid-name
 
 ############################## R.Callable ##############################
 
 
-class CallableProxy:
+class CallableProxy(StructInfoProxy):
+    params: List[StructInfoProxy]
+    ret: StructInfoProxy
     """Function type.
 
     A function type consists of a list of type parameters to enable
@@ -88,75 +143,81 @@ class CallableProxy:
 
     Parameters
     ----------
-    params : List[StructInfo]
-        The argument StructInfo
+    params : List[StructInfoProxy]
+        The argument StructInfoProxy
 
-    ret : StructInfo
-        The return StructInfo.
+    ret : StructInfoProxy
+        The return StructInfoProxy.
 
     """
 
-    def __call__(
+    def __init__(
         self,
-        params: Union[StructInfo, List[StructInfo], Tuple[StructInfo]],
-        ret: StructInfo,
-    ) -> relax.FuncStructInfo:
+        params: Union[StructInfoProxy, List[StructInfoProxy]],
+        ret: StructInfoProxy,
+    ) -> None:
         if not isinstance(params, (list, tuple)):
             params = [params]
-        return relax.FuncStructInfo(params, ret)
+        # convert `R.Tensor` to `R.Tensor()`
+        self.params = [param() if callable(param) else param for param in params]
+        self.ret = ret() if callable(ret) else ret
 
-    def __getitem__(self, keys) -> Var:
-        return self(*keys)  # pylint: disable=no-member # type: ignore
+    def get_symbolic_vars(self) -> Set[str]:
+        return set().union(*[p.get_symbolic_vars() for p in self.params])
 
+    def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncStructInfo:
+        params = [param.as_struct_info(dict_globals) for param in self.params]
+        ret = self.ret.as_struct_info(dict_globals)
+        return FuncStructInfo(params, ret)
+
+
+def Callable(
+    params: Union[StructInfoProxy, List[StructInfoProxy]],
+    ret: StructInfoProxy,
+) -> CallableProxy:
+    return CallableProxy(params, ret)
 
-Callable = CallableProxy()
 
 ############################### R.Tuple ################################
 
 
-class TupleProxy:
+class TupleProxy(StructInfoProxy):
+    fields: List[StructInfoProxy]
     """The type of tuple values.
 
     Parameters
     ----------
-    fields : List[Union[Expr, Type, StructInfo]]
+    fields : List[StructInfoProxy]
         The fields in the tuple
     """
 
-    def __call__(
+    def __init__(
         self,
-        *fields: List[Union[Expr, Type, StructInfo]],
-    ) -> Union[Expr, StructInfo]:
+        *fields: List[StructInfoProxy],
+    ) -> None:
         if len(fields) == 1 and isinstance(fields[0], (tuple, list)):
             fields = fields[0]
+        # convert `R.Tensor` to `R.Tensor()`
+        self.fields = [field() if callable(field) else field for field in fields]
 
-        if len(fields) == 0:
-            # Note: We cannot detect it's an expr or a struct info for empty tuple.
-            # So we return an expr by default, and use a spacial case in parser to fix it.
-            return RxTuple([])
+    def get_symbolic_vars(self) -> Set[str]:
+        return set().union(*[f.get_symbolic_vars() for f in self.fields])
 
-        if all([isinstance(f, Expr) for f in fields]):
-            return RxTuple(fields)
-        else:
-            fields = list(fields)
-            for i, x in enumerate(fields):
-                if callable(x):
-                    fields[i] = x()
-            if all([isinstance(f, StructInfo) for f in fields]):
-                return relax.TupleStructInfo(fields)
-            else:
-                raise TypeError(f"Invalid tuple type: {fields}")
+    def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TupleStructInfo:
+        fields = [field.as_struct_info(dict_globals) for field in self.fields]
+        return TupleStructInfo(fields)
 
-    def __getitem__(self, keys) -> Var:
-        return self(*keys)  # pylint: disable=no-member # type: ignore
 
+def Tuple(*fields: List[StructInfoProxy]) -> TupleProxy:
+    return TupleProxy(*fields)
 
-Tuple = TupleProxy()
 
 ############################### R.Shape ################################
 
 
-class ShapeProxy:
+class ShapeProxy(StructInfoProxy):
+    values: Optional[List[PrimExpr]]
+    ndim: int
     """The type of shape values.
 
     Parameters
@@ -168,18 +229,56 @@ class ShapeProxy:
         The size of the shape.
     """
 
-    def __call__(
+    def __init__(
         self,
         values: Optional[List[PrimExpr]] = None,
         ndim: int = -1,
-    ) -> StructInfo:
-        return relax.ShapeStructInfo(values, ndim)
+    ) -> None:
+        self.values = values
+        self.ndim = ndim
+
+    def get_symbolic_vars(self) -> Set[str]:
+        if self.values is None:
+            return set()
+        else:
+            return {v for v in self.values if isinstance(v, str) and v.isidentifier()}
+
+    def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
+        values = [_eval_shape(v, dict_globals) for v in self.values] if self.values else None
+        return ShapeStructInfo(values, self.ndim)
+
+
+def Shape(values: Optional[List[PrimExpr]] = None, ndim: int = -1) -> ShapeProxy:
+    return ShapeProxy(values, ndim)
+
+
+############################### R.Object ################################
+
+
+class ObjectProxy(StructInfoProxy):
+    """The proxy for ObjectStructInfo.
+
+    Note
+    ----
+    `R.Object` takes no parameters; it denotes an opaque Relax value
+    whose struct info is plain `ObjectStructInfo`.
+
+    Unlike `R.Tensor` or `R.Shape`, it therefore carries no symbolic
+    shape values.
+ """ + + def __init__(self) -> None: + pass + + def get_symbolic_vars(self) -> Set[str]: + return set() - def __getitem__(self, keys) -> Var: - return self(*keys) # pylint: disable=no-member # type: ignore + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo: + return ObjectStructInfo() -Shape = ShapeProxy() +def Object() -> ObjectProxy: + return ObjectProxy() ############################ R.match_cast ############################# diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index d48b05fa95ba..5873f265c833 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -18,11 +18,11 @@ import functools import numbers -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional from tvm import relax, tir -from tvm.ir import Type, structural_equal -from tvm.relax import Expr, StructInfo +from tvm.ir import structural_equal +from tvm.relax import StructInfo from tvm.relax.utils import convert_to_expr from tvm.script.ir_builder.relax.frame import BlockFrame @@ -30,7 +30,7 @@ from ...ir_builder import relax as R from ...ir_builder.base import IRBuilder from .._core import Parser, dispatch, doc -from .entry import MatchCastPair +from .entry import MatchCastPair, StructInfoProxy def bind_assign_value( @@ -87,56 +87,76 @@ def bind_assign_value( return var -# pylint: disable=inconsistent-return-statements -def eval_type_annotation( - self: Parser, node: Union[doc.Expression, doc.expr] -) -> Tuple[Type, Optional[Expr], StructInfo]: - annotation = self.eval_expr(node) - if callable(annotation): - annotation = annotation() - - if isinstance(annotation, relax.Tuple) and len(annotation.fields) == 0: - # Case for empty tuple - annotation = relax.TupleStructInfo([]) - - if isinstance(annotation, StructInfo): - var_table = {k: v for k, v in self.var_table.get().items() if isinstance(v, tir.Var)} - annotation, undefined_vars = R.RewriteSymbolicShape(annotation, var_table) - for var in undefined_vars: - self.var_table.add(var.name, var) - return annotation - else: - self.report_error(node, f"Unsupported type annotation {annotation}") +def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy: + try: + annotation = self.eval_expr(node) + if callable(annotation): + annotation = annotation() + if isinstance(annotation, StructInfoProxy): + return annotation + else: + raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.") + except Exception as err: + self.report_error(node, str(err)) + raise err + + +def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> StructInfo: + var_table = self.var_table.get() if eval_str else None + try: + return eval_struct_info_proxy(self, node).as_struct_info(var_table) + except Exception as err: + self.report_error(node, str(err)) + raise err + + +def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None: + # Collect symbolic vars from parameters + symbolic_vars = set() + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation) + symbolic_vars.update(param_sinfo_proxy.get_symbolic_vars()) + + # Define symbolic vars to the current var_table frame + for var_name in symbolic_vars: + self.var_table.add(var_name, tir.Var(var_name, "int64"), allow_shadowing=False) @dispatch.register(token="relax", 
type_name="FunctionDef") def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: with self.var_table.with_frame(): - with R.function(): - R.func_name(node.name) - if node.returns is not None: - ann_sinfo = eval_type_annotation(self, node.returns) - R.func_ret_struct_info(ann_sinfo) + with self.with_dispatch_token("relax"): + with R.function(): + R.func_name(node.name) + collect_symbolic_var_from_params(self, node) + + if node.returns is not None: + ann_sinfo = eval_struct_info(self, node.returns, eval_str=True) + R.func_ret_struct_info(ann_sinfo) - with self.with_dispatch_token("relax"): self.visit(node.args) self.visit_body(node.body) @dispatch.register(token="relax", type_name="tvm_declare_function") def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: - if node.returns is None: - ret_sinfo = relax.TupleStructInfo([]) - else: - ret_sinfo = eval_type_annotation(self, node.returns) - params = [] - params_sinfo = [] - for arg in node.args.args: - if arg.annotation is None: - self.report_error(arg, "Type annotation is required for function parameters.") - param_sinfo = self.visit_tvm_annotation(arg.annotation) - params_sinfo.append(param_sinfo) - params.append(relax.Var(arg.arg, param_sinfo)) + with self.var_table.with_frame(): + collect_symbolic_var_from_params(self, node) + + if node.returns is None: + ret_sinfo = relax.TupleStructInfo([]) + else: + ret_sinfo = eval_struct_info(self, node.returns, eval_str=True) + params = [] + params_sinfo = [] + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) + params_sinfo.append(param_sinfo) + params.append(relax.Var(arg.arg, param_sinfo)) func_signature = relax.Function.create_empty(params, ret_sinfo) global_var = I.decl_function(node.name, func_signature) @@ -172,15 +192,15 @@ def visit_arguments(self: Parser, node: doc.arguments) -> None: for arg in node.args: if arg.annotation is None: self.report_error(arg, "Type annotation is required for function parameters.") - param_sinfo = self.visit_tvm_annotation(arg.annotation) + param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) param = R.arg(arg.arg, param_sinfo) self.var_table.add(arg.arg, param) @dispatch.register(token="relax", type_name="tvm_annotation") -def visit_tvm_annotation(self: Parser, node: doc.expr): - return eval_type_annotation(self, node) +def visit_tvm_annotation(self: Parser, node: doc.expr) -> StructInfo: + return eval_struct_info(self, node, eval_str=False) @dispatch.register(token="relax", type_name="With") diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc index e7c24f1c4d29..1a9f14c8086d 100644 --- a/src/printer/relax_script_printer.cc +++ b/src/printer/relax_script_printer.cc @@ -43,7 +43,7 @@ Doc RelaxScriptPrinter::Print(const ObjectRef& node) { } else if (node->IsInstance()) { return VisitType(Downcast(node)); } else if (node->IsInstance()) { - return VisitExpr(Downcast(node)); + return PrintPrimExpr(Downcast(node)); } else if (node->IsInstance()) { return tir::AsTVMScriptDoc(Downcast(node), "T", false); } else if (node->IsInstance()) { @@ -59,7 +59,7 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::TupleNode* op) { size_t num_fields = op->fields.size(); if (num_fields == 0) { - return Doc::Text("R.Tuple()"); + return Doc::Text("R.tuple()"); } Doc doc; @@ -354,22 +354,30 @@ Doc RelaxScriptPrinter::VisitNode_(const 
@@ -354,22 +354,30 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::ExternFuncNode* op) {
   return Doc::StrLiteral(op->global_symbol);
 }
 
+Doc RelaxScriptPrinter::PrintPrimExpr(const PrimExpr& expr) {
+  Doc body = VisitExpr(expr);
+  if (print_symbolic_shape_as_str_ && !expr->IsInstance<IntImmNode>()) {
+    std::string s = body.str();
+    // replace all '\"' with '\''
+    // T.cast(m, "int32") ==> "T.cast(m, 'int32')"
+    std::replace(s.begin(), s.end(), '\"', '\'');
+    return Doc::Text("\"") << s << "\"";
+  } else {
+    return body;
+  }
+}
+
 Doc RelaxScriptPrinter::VisitExpr_(const tir::VarNode* op) {
   tir::Var var = GetRef<tir::Var>(op);
   if (!dim_var_map_.count(var)) {
    dim_var_map_[var] = GetUniqueName(var->name_hint, "dim");
   }
-  if (print_symbolic_shape_as_str_) {
-    Doc doc;
-    doc << "\"" << dim_var_map_[var] << "\"";
-    return doc;
-  } else {
-    if (std::none_of(symbolic_vars_.begin(), symbolic_vars_.end(),
-                     [&var](const tir::Var& v) { return v.same_as(var); })) {
-      symbolic_vars_.push_back(var);
-    }
-    return dim_var_map_[var];
+  if (!print_symbolic_shape_as_str_ &&
+      std::none_of(symbolic_vars_.begin(), symbolic_vars_.end(),
+                   [&var](const tir::Var& v) { return v.same_as(var); })) {
+    symbolic_vars_.push_back(var);
   }
+  return dim_var_map_[var];
 }
 
 Doc RelaxScriptPrinter::VisitExpr_(const tir::IntImmNode* op) {
@@ -379,8 +387,8 @@ Doc RelaxScriptPrinter::VisitExpr_(const tir::IntImmNode* op) {
 #define TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(OpName, OpString) \
   Doc RelaxScriptPrinter::VisitExpr_(const OpName* op) {          \
     Doc doc;                                                      \
-    doc << "(" << Print(op->a) << OpString;                       \
-    doc << Print(op->b) << ")";                                   \
+    doc << "(" << VisitExpr(op->a) << OpString;                   \
+    doc << VisitExpr(op->b) << ")";                               \
    return doc;                                                   \
   }
 
@@ -392,13 +400,13 @@ TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(tir::FloorDivNode, " // ")
 
 Doc RelaxScriptPrinter::VisitExpr_(const tir::CastNode* op) {
   Doc doc;
-  doc << "T.cast(" << Print(op->value) << ", " << PrintDType(op->dtype) << ")";
+  doc << "T.cast(" << VisitExpr(op->value) << ", " << PrintDType(op->dtype) << ")";
   return doc;
 }
 
 Doc RelaxScriptPrinter::VisitExpr_(const tir::MaxNode* op) {
   Doc doc;
-  doc << "T.max(" << Print(op->a) << ", " << Print(op->b) << ")";
+  doc << "T.max(" << VisitExpr(op->a) << ", " << VisitExpr(op->b) << ")";
   return doc;
 }
 
@@ -585,7 +593,6 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
     doc << " -> " << Print(ret_sinfo);
   }
   doc << ":" << Doc::NewLine(4);
-  // TODO(siyuan): Add printing of composite expression
   print_symbolic_shape_as_str_ = false;
 
   // Step 3: print function attr
diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h
index 08f50d601ecd..b2ee6de50313 100644
--- a/src/printer/text_printer.h
+++ b/src/printer/text_printer.h
@@ -330,6 +330,7 @@ class RelaxScriptPrinter : public relax::IRFunctor<Doc(const ObjectRef&)>,
   Doc VisitStructInfo_(const FuncStructInfoNode* op) override;
 
   Doc GetUniqueName(std::string prefix, std::string fallback);
+  Doc PrintPrimExpr(const PrimExpr& expr);
 
   /*!
    * \brief Attribute printer which prints the attributes as kwargs in a call.
diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc
index d831f9eedc0e..9db7cea6725d 100644
--- a/src/relax/ir/struct_info.cc
+++ b/src/relax/ir/struct_info.cc
@@ -84,6 +84,7 @@ ShapeStructInfo::ShapeStructInfo(Array<PrimExpr> values, Span span) {
 
 ShapeStructInfo::ShapeStructInfo(int ndim, Span span) {
   ObjectPtr<ShapeStructInfoNode> n = make_object<ShapeStructInfoNode>();
+  CHECK_GE(ndim, -1) << "ndim of ShapeStructInfo must be >= -1, but got " << ndim;
   n->ndim = ndim;
   n->span = span;
   data_ = std::move(n);
@@ -130,6 +131,7 @@ TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) {
 
 TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Span span) {
   ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>();
+  CHECK_GE(ndim, -1) << "ndim of TensorStructInfo must be >= -1, but got " << ndim;
   n->ndim = ndim;
   n->dtype = dtype;
   n->span = span;
diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc
index 1ece5cd9deff..e76de529b060 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -49,30 +49,6 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
       vid->name_hint = name;
     });
 
-////////////////////////////// Tensor Type //////////////////////////////
-
-using tvm::relax::TensorStructInfo;
-using tvm::relax::TupleStructInfo;
-
-TensorStructInfo Tensor(Optional<Array<PrimExpr>> shape, DataType dtype, int ndim) {
-  using namespace tvm::relax;
-  ICHECK_GE(ndim, -1) << "ndim must be >= -1, but got " << ndim;
-  if (shape.defined() && ndim >= 0) {
-    CHECK_EQ(shape.value().size(), ndim)
-        << "The dimension of the given shape is mismatched with the given `ndim`";
-  } else if (shape.defined()) {
-    ndim = shape.value().size();
-  }
-  if (shape.defined()) {
-    ShapeExpr shape_expr(shape.value());
-    return TensorStructInfo(shape_expr, dtype);
-  } else {
-    return TensorStructInfo(dtype, ndim);
-  }
-}
-
-TVM_REGISTER_GLOBAL("script.ir_builder.relax.Tensor").set_body_typed(Tensor);
-
 /////////////////////////////// Function ////////////////////////////////
 
 FunctionFrame Function() {
@@ -253,55 +229,6 @@ TVM_REGISTER_GLOBAL("script.ir_builder.relax.If").set_body_typed(If);
 TVM_REGISTER_GLOBAL("script.ir_builder.relax.Then").set_body_typed(Then);
 TVM_REGISTER_GLOBAL("script.ir_builder.relax.Else").set_body_typed(Else);
 
-//////////////////////// Symbolic Shape Rewriter ////////////////////////
-
-using tvm::relax::Expr;
-class SymbolicShapeRewriter : public tvm::relax::StructInfoMutator {
- public:
-  explicit SymbolicShapeRewriter(Map<String, tvm::tir::Var> var_table)
-      : var_table_(std::move(var_table)) {}
-
-  Array<tvm::tir::Var> undefined_vars_;
-
- private:
-  Expr VisitStructInfoExprField(const Expr& expr) {
-    if (const auto* shape_expr = expr.as<tvm::relax::ShapeExprNode>()) {
-      Array<tvm::PrimExpr> new_shape;
-      bool changed = false;
-      for (const tvm::PrimExpr& s : shape_expr->values) {
-        if (const auto* var = s.as<tvm::tir::VarNode>()) {
-          auto it = var_table_.find(var->name_hint);
-          if (it != var_table_.end()) {
-            new_shape.push_back((*it).second);
-            changed = true;
-          } else {
-            undefined_vars_.push_back(GetRef<tvm::tir::Var>(var));
-            var_table_.Set(var->name_hint, GetRef<tvm::tir::Var>(var));
-            new_shape.push_back(s);
-          }
-        } else {
-          // TODO(siyuan, ruihang): confirm and use VisitPrimExpr to recursive rewrite.
-          new_shape.push_back(s);
-        }
-      }
-      if (changed) {
-        return tvm::relax::ShapeExpr(new_shape);
-      }
-    }
-    return expr;
-  }
-
- private:
-  Map<String, tvm::tir::Var> var_table_;
-};
-
-TVM_REGISTER_GLOBAL("script.ir_builder.relax.RewriteSymbolicShape")
-    .set_body_typed([](tvm::relax::StructInfo struct_info, Map<String, tvm::tir::Var> var_table) {
-      SymbolicShapeRewriter rewriter(var_table);
-      tvm::relax::StructInfo rewritten_info = rewriter(std::move(struct_info));
-      return Array<ObjectRef>{rewritten_info, rewriter.undefined_vars_};
-    });
-
 }  // namespace relax
 }  // namespace ir_builder
 }  // namespace script
diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py
index 726a50d83301..300911703dad 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -352,7 +352,7 @@ def test_call_packed():
     @R.function
     def f(
         x: R.Tensor((32, "m"), "float32"),
-        y: R.Tensor(("m"), "float32"),
+        y: R.Tensor(("m",), "float32"),
         r: R.Tensor(dtype="int64"),
     ) -> R.Object:
         m = T.var("int64")
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py
index 1721f85ed40e..b654c581f153 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -16,12 +16,12 @@
 # under the License.
 
 import tvm.script
-
-from tvm import tir, relax
+import tvm.testing
+from tvm import relax
 from tvm.ir import assert_structural_equal
-
-from tvm.script import tir as T, relax as R
-from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode
+from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode
+from tvm.script import relax as R
+from tvm.script import tir as T
 
 
 def test_const_shape_arg():
@@ -406,3 +406,7 @@ def main(
     expected = Expected
     after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
     assert_structural_equal(after, expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py
index 07aa4a22aaf1..7825380fcfa2 100644
--- a/tests/python/relax/test_parser.py
+++ b/tests/python/relax/test_parser.py
@@ -39,7 +39,7 @@ def test_annotations():
     @R.function
     def f(
         x: R.Tensor((32, "m"), "float32"),
-        y: R.Tensor(("m"), "float32"),
+        y: R.Tensor(("m",), "float32"),
         r: R.Tensor(dtype="int64"),
     ) -> R.Object:
         m = T.var("int64")
@@ -302,7 +302,7 @@ def f(x: R.Tensor, y: R.Tensor((32,), "float32")):
 def test_tuplegetitem():
     @R.function
     def f(x: R.Tensor, y: R.Tensor):
-        t1 = R.Tuple((x, y))
+        t1 = R.tuple(x, y)
         t2 = (x, y)
         a = t1[0]
         b = R.TupleGetItem(t2, 1)
diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py
index 6cdb53c908a0..82e0c06589ef 100644
--- a/tests/python/relax/test_printer.py
+++ b/tests/python/relax/test_printer.py
@@ -40,7 +40,7 @@ def check_roundtrip(f_pre):
 
 def test_annotations():
     @R.function
-    def foo(x: R.Tensor((32, "m"), "float32"), y: R.Tensor(("m"), "float32")) -> R.Tensor:
+    def foo(x: R.Tensor((32, "m"), "float32"), y: R.Tensor(("m",), "float32")) -> R.Tensor:
         m = T.var("int64")
         z: R.Tensor((32, m), "float32") = R.multiply(x, y)
         w = R.multiply(z, z)
@@ -54,7 +54,7 @@ def foo(x: R.Tensor((32, "m"), "float32"), y: R.Tensor(("m"), "float32")) -> R.T
 def test_ndim_annotations():
     @R.function
     def foo(
-        x: R.Tensor((2, 3, 5), "float32", ndim=3),
+        x: R.Tensor((2, 3, 5), "float32"),
         y: R.Tensor(dtype="float32", ndim=-1),
         z: R.Tensor(dtype="float32", ndim=2),
     ):
@@ -103,7 +103,7 @@ def test_tuplegetitem():
     def foo(x: R.Tensor(ndim=2)):
         y = R.add(x, x)
         z = R.multiply(y, x)
-        t = R.Tuple((y, z))
+        t = (y, z)
         a = R.TupleGetItem(t, 0)
         b = R.TupleGetItem(t, 1)
         c = R.multiply(a, b)
@@ -344,7 +344,6 @@ def h(
     check_roundtrip(my_module)
 
 
-@pytest.mark.skip("Need to fix string ast expr")
 def test_tir_max():
     @R.function
     def tir_max(x: R.Tensor(("m", "n"), "float32")):
@@ -355,7 +354,6 @@ def tir_max(x: R.Tensor(("m", "n"), "float32")):
     check_roundtrip(tir_max)
 
 
-@pytest.mark.skip("Need to fix string ast expr")
 def test_tir_cast():
     @R.function
     def tir_cast(x: R.Tensor(("m",), "float32")):
@@ -413,7 +411,7 @@ def global_func_2(
 
     @R.function
     def local_func_2(
-        y: R.Tensor(("m", "n"), "float32")
+        y: R.Tensor((m, n), "float32")
     ) -> R.Callable((R.Tensor((m, n), "float32"),), R.Tensor((m, n), "float32")):
         @R.function
         def local_func_3(
diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py
index 9c352a8e427f..9037b907b445 100644
--- a/tests/python/relax/test_transform_normalize.py
+++ b/tests/python/relax/test_transform_normalize.py
@@ -259,7 +259,7 @@ def test_normalize_if_condition():
     def expected(
         cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")
     ) -> R.Tensor(dtype="float32", ndim=1):
-        c = R.TupleGetItem(R.Tuple(cond), 0)
+        c = R.TupleGetItem(R.tuple(cond), 0)
         if c:
             gv = R.add(x, x)
             y = gv
diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py
index b154e85b8cc8..f1994c6b3133 100644
--- a/tests/python/relax/test_tvmscript_ir_builder.py
+++ b/tests/python/relax/test_tvmscript_ir_builder.py
@@ -33,8 +33,8 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2)
         with R.function():
             R.func_name("foo")
             R.func_attr({"Primitive": 1})
-            x = R.arg("x", R.tensor((128, 128), "float32"))
-            R.func_ret_struct_info(R.tensor(dtype="float32", ndim=2))
+            x = R.arg("x", relax.TensorStructInfo((128, 128), "float32"))
+            R.func_ret_struct_info(relax.TensorStructInfo(dtype="float32", ndim=2))
             out = R.emit(R.call_tir("extern_func", x, (128, 128), dtype="float32"))
             IRBuilder.name("out", out)
             R.func_ret_value(out)
@@ -67,12 +67,12 @@ def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")):
     with IRBuilder() as ir_builder:
         with R.function():
             R.func_name("foo")
-            x = R.arg("x", R.tensor(ndim=-1, dtype="float32"))
-            y = R.arg("y", R.tensor(ndim=-1, dtype="float32"))
+            x = R.arg("x", relax.TensorStructInfo(ndim=-1, dtype="float32"))
+            y = R.arg("y", relax.TensorStructInfo(ndim=-1, dtype="float32"))
             m = tir.Var("m", dtype="int64")
             n = tir.Var("n", dtype="int64")
-            _ = R.emit_match_cast(x, R.tensor((m,), "float32"))
-            y1 = R.emit_match_cast(y, R.tensor((n,), "float32"))
+            _ = R.emit_match_cast(x, relax.TensorStructInfo((m,), "float32"))
+            y1 = R.emit_match_cast(y, relax.TensorStructInfo((n,), "float32"))
             IRBuilder.name("y1", y1)
             R.func_ret_value(relax.ShapeExpr([m, n * 2]))
     func = ir_builder.get()
@@ -84,8 +84,8 @@ def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")):
     n = tir.Var("n", dtype="int64")
     bb = relax.BlockBuilder()
     with bb.function("foo", (x, y)):
-        _ = bb.match_cast(x, R.tensor((m,), "float32"))
-        y1 = bb.match_cast(y, R.tensor((n,), "float32"))
+        _ = bb.match_cast(x, relax.TensorStructInfo((m,), "float32"))
+        y1 = bb.match_cast(y, relax.TensorStructInfo((n,), "float32"))
         bb.emit_func_output(relax.ShapeExpr([m, n * 2]))
     mod = bb.get()
ndim = 2): with IRBuilder() as ir_builder: with R.function(): R.func_name("foo") - x = R.arg("x", R.tensor((128, 128), "float32")) + x = R.arg("x", relax.TensorStructInfo((128, 128), "float32")) with R.dataflow() as df: lv0 = R.emit(R.call_tir("extern_func", x, (128, 128), dtype="float32")) IRBuilder.name("lv0", lv0) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index f81b6fa9538e..61eef1758dd7 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -19,6 +19,7 @@ import pytest import tvm +import tvm.script import tvm.testing from tvm import IRModule, relax, tir from tvm.relax import DynTensorType @@ -467,7 +468,7 @@ def test_annotation(): @R.function def foo( x: R.Tensor((32, "m"), "float32"), - y: R.Tensor(("m"), "float32"), + y: R.Tensor(("m",), "float32"), r: R.Tensor(dtype="int64"), ) -> R.Object: m = T.var("int64", "m") @@ -773,7 +774,7 @@ def foo(x: R.Tensor): def test_empty_tuple(): @R.function def foo(x: R.Tuple()): - y: R.Tuple() = R.Tuple() + y: R.Tuple() = R.tuple() return y x = relax.Var("x", relax.TupleStructInfo([])) @@ -785,6 +786,68 @@ def foo(x: R.Tuple()): _check(foo, bb.get()["foo"]) +def test_symbolic_shape_computing(): + # Tensor Case 1 + @R.function + def foo(x: R.Tensor(("m + 1",), "float32"), y: R.Tensor(("m", 1), "float32")): + z = R.add(x, y) + return z + + m = tir.Var("m", "int64") + x = relax.Var("x", relax.TensorStructInfo([m + 1], "float32")) + y = relax.Var("y", relax.TensorStructInfo([m, 1], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + z = bb.emit(relax.op.add(x, y)) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + # Tensor Case 2 + @R.function + def bar( + x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32") + ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"): + m = T.var("int64") + z = R.call_tir("test_intrin", (x, y), (T.max(m, 20) + 1,), dtype="float32") + return z + + m = tir.Var("m", "int64") + x = relax.Var("x", relax.TensorStructInfo([m], "float32")) + y = relax.Var("y", relax.TensorStructInfo([tir.max(m, 20)], "float32")) + bb = relax.BlockBuilder() + with bb.function("bar", (x, y)): + z = bb.emit(relax.call_tir("test_intrin", (x, y), (tir.max(m, 20) + 1,), dtype="float32")) + bb.emit_func_output(z) + + _check(bar, bb.get()["bar"]) + + # Shape Case + @R.function + def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): + m = T.var("int64") + z = R.call_tir("test_intrin", y, (m * 2,), dtype="float32") + return z + + m = tir.Var("m", "int64") + x = relax.Var("x", relax.ShapeStructInfo([m])) + y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) + bb = relax.BlockBuilder() + with bb.function("baz", (x, y)): + z = bb.emit(relax.call_tir("test_intrin", (y), (m * 2,), dtype="float32")) + bb.emit_func_output(z) + + _check(baz, bb.get()["baz"]) + + # Error Case + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor(("m + 1", "m * 2"), "float32")): # name 'm' is not defined + z = R.add(x, x) + return z + + @pytest.mark.skip(reason="potential upstream Metadata changes.") def test_meta(): metadata = tvm.ir.load_json(