8 changes: 8 additions & 0 deletions docs/api/python/ir.rst
@@ -21,3 +21,11 @@ tvm.ir
:members:
:imported-members:
:autosummary:


tvm.transform
-------------
.. automodule:: tvm.transform
:members:
:imported-members:
:autosummary:
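
For orientation, a minimal sketch of what the newly documented module exposes, assuming a standard build where tvm.transform re-exports the pass infrastructure from tvm.ir.transform:

import tvm

# Core pass-infrastructure objects now documented under tvm.transform.
print(tvm.transform.PassContext.current())  # the active (default) pass context
print(tvm.transform.Sequential)             # pass that runs a list of passes in order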
2 changes: 1 addition & 1 deletion docs/dev/convert_layout.rst
@@ -227,7 +227,7 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r

# Convert the layout to NCHW
# RemoveUnusedFunctions is used to clean up the graph.
seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
relay.transform.ConvertLayout('NCHW')])
with relay.transform.PassContext(opt_level=3):
mod = seq(mod)
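A self-contained sketch of the sequence above on a small NHWC conv2d; shapes and variable names are illustrative, not from the PR:

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 56, 56, 32), dtype="float32")
weight = relay.var("weight", shape=(3, 3, 32, 32), dtype="float32")
conv = relay.nn.conv2d(data, weight, channels=32, kernel_size=(3, 3),
                       data_layout="NHWC", kernel_layout="HWIO")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# Same pipeline as the docs: drop unused functions, then convert to NCHW.
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                relay.transform.ConvertLayout('NCHW')])
with relay.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod["main"])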
4 changes: 2 additions & 2 deletions docs/dev/relay_pass_infra.rst
@@ -582,7 +582,7 @@ using ``Sequential`` associated with other types of passes.
func = relay.Function([x], z2)

# Customize the optimization pipeline.
seq = _transform.Sequential([
seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
@@ -609,7 +609,7 @@ sequential pass example could be like the following to enable IR dumping for

.. code:: python

seq = _transform.Sequential([
seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.PrintIR(),
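For reference, a runnable sketch of the customized pipeline with IR dumping; the toy function is illustrative:

import tvm
from tvm import relay

x = relay.var("x", shape=(2, 3), dtype="float32")
body = relay.add(x, relay.add(relay.const(1.0), relay.const(2.0)))
mod = tvm.IRModule.from_expr(relay.Function([x], body))

seq = tvm.transform.Sequential([relay.transform.InferType(),
                                relay.transform.FoldConstant(),
                                relay.transform.PrintIR(),
                                relay.transform.EliminateCommonSubexpr()])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)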
4 changes: 3 additions & 1 deletion include/tvm/ir/transform.h
@@ -361,9 +361,11 @@ TVM_DLL Pass CreateModulePass(

/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
* \param header The header to be attached to the output.
* \param show_meta_data Whether to show the meta data.
* \return The pass.
*/
TVM_DLL Pass PrintIR(std::string header);
TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false);

} // namespace transform
} // namespace tvm
2 changes: 1 addition & 1 deletion python/tvm/ir/json_compact.py
@@ -106,7 +106,7 @@ def _update_global_key(item, _):
"relay.PassInfo": _rename("transform.PassInfo"),
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequantial": _rename("transform.Sequantial"),
"relay.Sequential": _rename("transform.Sequential"),
# TIR
"Variable": _update_tir_var("tir.Var"),
"SizeVar": _update_tir_var("tir.SizeVar"),
7 changes: 5 additions & 2 deletions python/tvm/ir/transform.py
@@ -329,16 +329,19 @@ def create_module_pass(pass_arg):
return create_module_pass


def PrintIR(header):
def PrintIR(header="", show_meta_data=False):
"""A special trace pass that prints the header and IR.

Parameters
----------
header : str
The header to be displayed along with the dump.

show_meta_data : bool
A boolean flag to indicate if meta data should be printed.

Returns
--------
The pass
"""
return _ffi_transform_api.PrintIR(header)
return _ffi_transform_api.PrintIR(header, show_meta_data)
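
A hedged usage sketch of the extended signature, assuming PrintIR is exported from tvm.transform like the other pass-infra helpers in this diff; the pass is applied directly to a module here:

import tvm
from tvm import relay

x = relay.var("x", shape=(4,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.exp(x)))

# Existing PrintIR() calls keep their behaviour: empty header, no meta data.
mod = tvm.transform.PrintIR()(mod)
# New form: attach a header and include the meta data section in the dump.
mod = tvm.transform.PrintIR(header="after import", show_meta_data=True)(mod)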
11 changes: 0 additions & 11 deletions python/tvm/relay/__init__.py
@@ -128,20 +128,9 @@
# Scope builder
ScopeBuilder = scope_builder.ScopeBuilder

module_pass = transform.module_pass
function_pass = transform.function_pass

# Parser
fromtext = parser.fromtext

# Param Serialization
save_param_dict = param_dict.save_param_dict
load_param_dict = param_dict.load_param_dict

# Pass manager
PassInfo = transform.PassInfo
PassContext = transform.PassContext
Pass = transform.Pass
ModulePass = transform.ModulePass
FunctionPass = transform.FunctionPass
Sequential = transform.Sequential
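Code that reached the pass manager through these relay-level aliases now needs the tvm.transform spellings used throughout this diff; a minimal migration sketch, where the pass body is a hypothetical no-op for illustration:

import tvm

# Previously: @relay.module_pass(...), relay.Sequential, relay.PassContext
@tvm.transform.module_pass(opt_level=1, name="IdentityPass")
def identity_pass(mod, ctx):
    # No-op module pass, purely for illustration.
    return mod

seq = tvm.transform.Sequential([identity_pass])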
8 changes: 4 additions & 4 deletions python/tvm/relay/backend/interpreter.py
@@ -210,10 +210,10 @@ def optimize(self):
opt_mod : tvm.IRModule
The optimized module.
"""
seq = transform.Sequential([transform.SimplifyInference(),
transform.FuseOps(0),
transform.ToANormalForm(),
transform.InferType()])
seq = tvm.transform.Sequential([transform.SimplifyInference(),
transform.FuseOps(0),
transform.ToANormalForm(),
transform.InferType()])
return seq(self.mod)

def _make_executor(self, expr=None):
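For context, this optimize() pipeline is what the debug executor runs before interpretation; a hedged sketch of invoking it, assuming an LLVM-enabled build for the default target:

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, x)))

# "debug" selects the Interpreter backend shown above.
intrp = relay.create_executor("debug", mod=mod)
out = intrp.evaluate()(np.ones(2, dtype="float32"))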
4 changes: 2 additions & 2 deletions python/tvm/relay/qnn/transform.py
@@ -60,7 +60,7 @@ def @main(%quantized_data: Tensor[(200), int32]) -> Tensor[(200), int8] {

Returns
-------
ret : tvm.relay.Pass
ret : tvm.transform.Pass
The registered pass that canonicalizes QNN ops to Relay ops.
"""

@@ -108,7 +108,7 @@ def Legalize():

Returns
-------
ret : tvm.relay.Pass
ret : tvm.transform.Pass
The registered pass that legalizes QNN ops.
"""

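The docstrings now point at the base class under tvm.transform; a small hedged check in the same spirit as run_opt_pass below, assuming relay.qnn is available in the build:

import tvm
from tvm import relay

legalize = relay.qnn.transform.Legalize()
assert isinstance(legalize, tvm.transform.Pass)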
33 changes: 18 additions & 15 deletions python/tvm/relay/quantize/quantize.py
@@ -17,6 +17,7 @@
#pylint: disable=unused-argument, not-context-manager
"""Automatic quantization toolkit."""
import tvm.ir
import tvm
from tvm.runtime import Object

from . import _quantize
@@ -240,7 +241,7 @@ def partition():

Returns
-------
ret: tvm.relay.Pass
ret: tvm.transform.Pass
The registered pass for VTA rewrite.
"""
return _quantize.QuantizePartition()
@@ -253,7 +254,7 @@ def annotate():

Returns
-------
ret: tvm.relay.Pass
ret: tvm.transform.Pass
The registered pass for quantization annotation.
"""
return _quantize.QuantizeAnnotate()
@@ -267,7 +268,7 @@ def realize():

Returns
-------
ret: tvm.relay.Pass
ret: tvm.transform.Pass
The registered pass for quantization realization.
"""
return _quantize.QuantizeRealize()
@@ -298,11 +299,12 @@ def prerequisite_optimize(mod, params=None):
""" Prerequisite optimization passes for quantization. Perform
"SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization. """
optimize = _transform.Sequential([_transform.SimplifyInference(),
_transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])
optimize = tvm.transform.Sequential(
[_transform.SimplifyInference(),
_transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])

if params:
mod['main'] = _bind_params(mod['main'], params)
@@ -336,19 +338,20 @@ def quantize(mod, params=None, dataset=None):
"""
mod = prerequisite_optimize(mod, params)

calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1,
name="QuantizeCalibrate")
calibrate_pass = tvm.transform.module_pass(
calibrate(dataset), opt_level=1,
name="QuantizeCalibrate")
quant_passes = [partition(),
annotate(),
calibrate_pass]
if not current_qconfig().do_simulation:
quant_passes.append(realize())
quant_passes.append(_transform.FoldConstant())
quantize_seq = _transform.Sequential(quant_passes)
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
quantize_seq = tvm.transform.Sequential(quant_passes)
with tvm.transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
with quantize_context():
mod = quantize_seq(mod)

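The passes above are normally reached through prerequisite_optimize and quantize; a hedged sketch of the first step on a toy module, assuming prerequisite_optimize is re-exported from relay.quantize:

import tvm
from tvm import relay

x = relay.var("x", shape=(4,), dtype="float32")
body = relay.add(x, relay.add(relay.const(2.0), relay.const(3.0)))
mod = tvm.IRModule.from_expr(relay.Function([x], body))

# Runs SimplifyInference, FoldConstant, FoldScaleAxis, CanonicalizeOps, FoldConstant.
mod = relay.quantize.prerequisite_optimize(mod)
print(mod["main"])  # the 2.0 + 3.0 subexpression is folded away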
2 changes: 1 addition & 1 deletion python/tvm/relay/testing/__init__.py
@@ -47,7 +47,7 @@
from ..transform import gradient

def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass)
assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
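A hedged example of calling this helper with any tvm.transform.Pass; the expression is illustrative:

from tvm import relay
from tvm.relay.testing import run_opt_pass

x = relay.var("x", shape=(2, 2), dtype="float32")
typed = run_opt_pass(relay.add(x, x), relay.transform.InferType())
print(typed.checked_type)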
4 changes: 2 additions & 2 deletions python/tvm/relay/testing/py_converter.py
@@ -95,8 +95,8 @@ def optimize(self, prog: Expr):

# necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
# and fusion (to get primitive functions)
opts = relay.transform.Sequential([relay.transform.SimplifyInference(),
relay.transform.FuseOps(fuse_opt_level=0)])
opts = tvm.transform.Sequential([relay.transform.SimplifyInference(),
relay.transform.FuseOps(fuse_opt_level=0)])
mod = opts(mod)
optimized = mod['main']
return optimized if isinstance(unwrapped, Function) else optimized.body