Skip to content

Commit

Permalink
Add metadata section, support constant and metadata in parser & print…
Browse files Browse the repository at this point in the history
…er (apache#76)

* [CI] Set up CI; format and lint relax code to pass CI (apache#72)

* init

* fix lint

* update task_lint

* more lint

* more lint

* lint

* jenkinsfile

* jenkinsfile

* run relax only tests

* python3.7 for pytest

* point to personal ci-cpu docker

* docker pull

* test

* fix cmake config

* update

* update

* rebase

* rebase

* AutoTIR integration (apache#58)

* [WIP] Basic task extraction mechanism is implemented.

* [WIP] For gradual integration with Relay pipeline, meta_schedule/integration.py is created for relax to avoid potential conflict.

* support tir tuning and injection mode

* Add target field for Relax Extracted Task

* 1. Create relax namespace/tvm objects/... for metaschedule to preserve relay support. 2. Promote target field from Optional<Target> to Target

* Support ApplyHistoryBest

* Reflect feedback from Yuchen

* minor improvement and fix linter issue

* add ASF header

* Reorganize file structure

* fix lint errors

* remove the import-outside-toplevel

* Reflect comments

* remove redundant comment

* As per discussion w/ Yuchen, ApplyHistoryBest is introduced as a Relax transformation pass.

* remove redundant print msg

* fix lint

* reflect comments

* Yuchen's change

* relax ConstantNode in parser and printer

* Add constant data in the metasection

* rebase

* Support ir_module(metadata=json_str)

* update test case

* remove print info

* Update tests

* clang-format

* pylint

* fix ci

* Save a copy of metadata in RelaxTransformer

* Fix comments

* fix comments

Co-authored-by: Yuchen Jin <yuchenj@cs.washington.edu>
Co-authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com>
  • Loading branch information
3 people authored and junrushao committed Feb 9, 2023
1 parent a12470a commit 5c2ad96
Show file tree
Hide file tree
Showing 15 changed files with 576 additions and 160 deletions.
8 changes: 7 additions & 1 deletion include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ class IRModuleNode : public Object {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}

/*!
* \brief Get the metadata attributes.
* \returns The additional meta-data attributes
*/
DictAttrs GetAttrs() const { return attrs; }

/*!
* \brief Check whether the module has an non-zero integer attr.
*
Expand Down Expand Up @@ -353,7 +359,7 @@ class IRModule : public ObjectRef {
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module.
* \param map The module source map.
* \param attrs The module attributes.
* \param attrs The module meta-data attributes.
*/
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, TypeData> type_definitions = {},
Expand Down
1 change: 0 additions & 1 deletion include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ class VarNode : public ExprNode {
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) &&
// Do we use the analysis information in equality?
equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_);
}

Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class ConstantNode : public ExprNode {
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
v->Visit("shape_", &shape_);
}

bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False):
The left operand.
rhs : Object
The left operand.
The right operand.
map_free_vars : bool
Whether or not shall we map free vars that does
Expand Down
29 changes: 26 additions & 3 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""IRModule that holds the functions and type definitions."""
import ast

import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Scriptable

from ..ir.function import BaseFunc
from . import _ffi_api
from . import expr as _expr
from ..ir.function import BaseFunc
from . import type as _ty
from .base import Node

Expand All @@ -38,7 +40,7 @@ class IRModule(Node, Scriptable):
Map of global var to BaseFunc
"""

def __init__(self, functions=None, type_definitions=None):
def __init__(self, functions=None, type_definitions=None, attrs=None):
if functions is None:
functions = {}
elif isinstance(functions, dict):
Expand All @@ -61,7 +63,17 @@ def __init__(self, functions=None, type_definitions=None):
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v
type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)

attrs = None if not attrs else attrs
if attrs is not None:
attrs = ast.literal_eval(str(attrs))
attrs = tvm.ir.make_node("DictAttrs", **attrs)
self.__init_handle_by_constructor__(
_ffi_api.IRModule,
functions,
type_definitions,
attrs,
)

def __setitem__(self, var, val):
"""Add a mapping to the module.
Expand Down Expand Up @@ -268,6 +280,17 @@ def get_attr(self, attr_key):

return _ffi_api.Module_GetAttr(self, attr_key)

def get_attrs(self):
"""Get the meta_data attributes.
Returns
-------
meta_data : DictAttrs
meta_data attributes
"""

return _ffi_api.Module_GetAttrs(self)

def with_attr(self, attr_key, attr_value):
"""Copy the IRModule and add an attribute to it.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Call = relay.Call
If = relay.If
const = relay.const
Constant = relay.Constant


@tvm._ffi.register_object("relax.expr.ShapeExpr")
Expand Down
56 changes: 55 additions & 1 deletion python/tvm/relax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,66 @@
from tvm.tir import PrimFunc
from tvm import IRModule

# Simply extracts tir PrimFuncs from the input IRModule

def tir_partitioner(mod: IRModule) -> List[IRModule]:
"""Extracts tir PrimFuncs from the input IRModule.
Parameters
----------
mod : IRModule
The input IRModule.
Returns
-------
output : List[IRModule]
The result tir PrimFuncs.
"""
partitions = []
for gvar in mod.get_global_vars():
if isinstance(mod[gvar], PrimFunc):
tir_mod = IRModule({})
tir_mod[gvar] = mod[gvar]
partitions.append(tir_mod)
return partitions


def metadata_partitioner(rx_txt: str) -> List[str]:
"""Extract Relax program and metadata section.
Parameters
----------
rx_txt : str
The input relax text.
Returns
-------
output : List[str]
The result list of partitioned text, the first element
is the relax program, and the second is metadata section.
"""
partitions = []
left_curly = 0
meta_start = 0
meta_end = 0
for i, char in enumerate(rx_txt):
if i < 0:
raise ValueError("The program is invalid.")
if char == "{":
if meta_start == 0:
meta_start = i
left_curly += 1
elif char == "}":
left_curly -= 1
if left_curly == 0:
meta_end = i + 1
break

if meta_end == 0:
raise ValueError("The metadata section was not found.")
metadata = rx_txt[meta_start:meta_end]
rx_program = rx_txt[meta_end:-1]

partitions.append(rx_program)
partitions.append(metadata)

return partitions
29 changes: 22 additions & 7 deletions python/tvm/script/relax/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,47 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script Interface for Relax Functions"""
# pylint: disable=import-outside-toplevel

import inspect
from typing import Callable
import functools

from tvm.relax import Function

from .parser import from_source


def function(input_func: Callable) -> Function:
def function(input_func=None, metadata=None) -> Function:
"""Decorate a Python function as a Relax function in TVM script.
Parameters
----------
input_func : Callable
The function to be parsed.
metadata : Optional[Union[str, DictAttrs]]
The meta_data attributes to be parsed.
Returns
-------
output : Function
The parsed Relax Function.
"""
if inspect.isfunction(input_func):
result = from_source(input_func)
result.__name__ = input_func.__name__
result.__qualname__ = input_func.__qualname__
return result
if metadata is not None:
from .parser import RelaxTransformer as _RelaxTransformer

_RelaxTransformer.update_meta(metadata)

if input_func is None:
return functools.partial(function, metadata=metadata)

def _function(input_func: Callable) -> Function:
if inspect.isfunction(input_func):
result = from_source(input_func)
result.__name__ = input_func.__name__
result.__qualname__ = input_func.__qualname__
return result
raise TypeError("Only function definitions are supported.")

raise TypeError("Only function definitions are supported.")
return _function(input_func)
Loading

0 comments on commit 5c2ad96

Please sign in to comment.