Migrate passes to Pass Infra (apache#37)
* Migrate relax passes -> Pass infra.

* Update.

* Add docs and update tests.

* Rebase and change namespace.

* Address comments.
YuchenJin authored and yongwww committed Jun 12, 2022
1 parent 3f001ec commit a081c76
Showing 13 changed files with 508 additions and 173 deletions.
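The heart of the change is visible across the diffs below: Relax passes stop being free functions from IRModule to IRModule and become tvm.transform.Pass objects built on TVM's pass infrastructure. A minimal calling-convention sketch (hedged: `mod` stands for an existing relax IRModule, not defined here):

    import tvm
    from tvm import relax

    # Each factory returns a tvm.transform.Pass; calling the pass on an
    # IRModule returns a transformed IRModule, so passes compose declaratively.
    fma = relax.transform.FMARewrite()
    new_mod = fma(mod)  # `mod` is assumed to be an existing relax IRModule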
51 changes: 51 additions & 0 deletions include/tvm/relax/backend.h
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/backend.h
* \brief Relax backend specific transformation passes.
*/
#ifndef TVM_RELAX_BACKEND_H_
#define TVM_RELAX_BACKEND_H_

#include <tvm/relax/transform.h>

namespace tvm {
namespace relax {
namespace transform {

/*!
* \brief Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics.
*
* \return The Pass.
*/
TVM_DLL Pass VMMemoryLower();

/*!
* \brief Lower the shape expression in relax to VM shape heap and TIR functions.
*
* \return The Pass.
*/
TVM_DLL Pass VMShapeLower();

} // namespace transform
} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_BACKEND_H_
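These two declarations are the backend half of the migration: both become function passes that run late in VM compilation, with memory lowering ahead of shape lowering (the order vm.py uses below). A sketch of invoking them together, assuming `mod` is a relax IRModule that has already been through the frontend passes:

    import tvm
    from tvm import relax

    lowering = tvm.transform.Sequential(
        [
            relax.transform.VMMemoryLower(),  # alloc_tensor -> VM storage/tensor builtins
            relax.transform.VMShapeLower(),   # shape exprs -> shape heap + TIR shape funcs
        ]
    )
    lowered_mod = lowering(mod)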
10 changes: 5 additions & 5 deletions include/tvm/relax/expr.h
@@ -132,7 +132,7 @@ class Var : public Expr {
TVM_DLL explicit Var(Id vid, runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);
-  TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
};

/*! \brief A sub-type of the variable node used to mark dataflow variables from
@@ -414,11 +414,11 @@ class FunctionNode : public BaseFuncNode {
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
};

-class Function : public Expr {
+class Function : public BaseFunc {
public:
TVM_DLL explicit Function(runtime::Optional<GlobalVar> name, Array<Var> params, Expr body,
Type ret_type, Span span = Span());
-  TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
};

@@ -445,10 +445,10 @@ class ExternFuncNode : public BaseFuncNode {
TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncNode, BaseFuncNode);
};

-class ExternFunc : public Expr {
+class ExternFunc : public BaseFunc {
public:
TVM_DLL ExternFunc(String global_symbol, Span span = Span());
-  TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, Expr, ExternFuncNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, BaseFunc, ExternFuncNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode);
};

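The re-parenting of Function and ExternFunc from Expr to BaseFunc is what lets relax functions share an IRModule with TIR PrimFuncs, which the pass infra (and _split_tir_relax in vm.py below) relies on. A sketch of the resulting uniform view, assuming `mod` is an IRModule holding both kinds of functions:

    import tvm
    from tvm import relax

    for gv, fn in mod.functions.items():  # every value is a BaseFunc subclass
        print(gv.name_hint, type(fn))     # relax.Function or tir.PrimFunc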
85 changes: 85 additions & 0 deletions include/tvm/relax/transform.h
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/transform.h
* \brief Relax specific transformation passes.
*/
#ifndef TVM_RELAX_TRANSFORM_H_
#define TVM_RELAX_TRANSFORM_H_

#include <tvm/ir/transform.h>
#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {
namespace transform {

using Pass = tvm::transform::Pass;
using PassInfo = tvm::transform::PassInfo;
using PassContext = tvm::transform::PassContext;
using Function = tvm::relax::Function;

/*!
* \brief Create a function pass.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the function pass.
* \param name The name of the function pass.
* \param required The list of the passes that the function pass is dependent on.
*
* \return The created function pass.
*/
TVM_DLL Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level, String name, tvm::Array<String> required);

/*!
* \brief Perform fused multiply add rewriting in dataflow blocks.
*
* \return The Pass.
*/
TVM_DLL Pass FMARewrite();

/*!
* \brief Transform all dataflow structure to non-dataflow version.
*
* \return The Pass.
*/
TVM_DLL Pass ToNonDataflow();

/*!
* \brief Perform explicit tensor allocation for call_dps.
*
* \return The Pass.
*/
TVM_DLL Pass CallDPSRewrite();

/*!
* \brief Transform Relax IR to A-normal form.
*
* \return The Pass.
*/
TVM_DLL Pass ToANF();

} // namespace transform
} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_TRANSFORM_H_
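Every pass declared in this header is constructed through CreateFunctionPass, so each one carries PassInfo metadata (name, opt_level, required) and is applied to each relax Function in a module rather than to the module wholesale. A sketch of what that buys callers on the Python side (the printed values are assumptions inferred from the registrations in this PR):

    import tvm
    from tvm import relax

    p = relax.transform.ToNonDataflow()
    print(p.info.name)       # "ToNonDataflow", the name given to CreateFunctionPass
    print(p.info.opt_level)  # 0 for the passes in this PR (assumption)
    new_mod = p(mod)         # `mod` is assumed to be an existing relax IRModule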
75 changes: 41 additions & 34 deletions python/tvm/relax/transform/transform.py
@@ -16,64 +16,71 @@
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
-from tvm import IRModule
+import tvm.ir
from . import _ffi_api

+@tvm._ffi.register_object("relax.FunctionPass")
+class FunctionPass(tvm.ir.transform.Pass):
+    """A pass that works on each tvm.relax.Function in a module. A function
+    pass class should be created through `function_pass`.
+    """

-def fma_rewrite(expr):
+def FMARewrite() -> tvm.transform.Pass:
"""Perform fused multiply add rewriting in dataflow blocks.
-    Parameters
-    ----------
-    expr : tvm.relay.Expr
-        The input expression.
+    Returns
+    -------
+    ret: tvm.transform.Pass
"""
-    return _ffi_api.fma_rewrite(expr)
+    return _ffi_api.FMARewrite()


-def to_non_dataflow(mod: IRModule) -> IRModule:
+def ToNonDataflow() -> tvm.transform.Pass:
"""Transform all dataflow structure to non-dataflow version.
-    Parameters
-    ----------
-    mod : tvm.IRModule
-        The input module.
+    Returns
+    -------
+    ret: tvm.transform.Pass
"""
-    return _ffi_api.to_non_dataflow(mod)
+    return _ffi_api.ToNonDataflow()


-def call_dps_rewrite(mod: IRModule) -> IRModule:
+def CallDPSRewrite() -> tvm.transform.Pass:
"""Perform explicit tensor allocation for call_dps.
-    Parameters
-    ----------
-    mod : tvm.IRModule
-        The input module.
+    Returns
+    -------
+    ret: tvm.transform.Pass
"""
-    return _ffi_api.call_dps_rewrite(mod)
+    return _ffi_api.CallDPSRewrite()


-def vm_memory_lower(mod: IRModule) -> IRModule:
+def VMMemoryLower() -> tvm.transform.Pass:
"""Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics.
-    Parameters
-    ----------
-    mod : tvm.IRModule
-        The input module.
+    Returns
+    -------
+    ret: tvm.transform.Pass
"""
-    return _ffi_api.vm_memory_lower(mod)
+    return _ffi_api.VMMemoryLower()


-def vm_shape_lower(mod: IRModule) -> IRModule:
-    """Lower the shape expression in relax to VM shape heap and TIR functions.
+def VMShapeLower() -> tvm.transform.Pass:
+    """Lower the shape expressions in relax to VM shape heap manipulations and generate related
+    TIR functions to do shape calculations.
-    Parameters
-    ----------
-    mod : tvm.IRModule
-        The input module.
+    Returns
+    -------
+    ret: tvm.transform.Pass
"""
-    return _ffi_api.vm_shape_lower(mod)
+    return _ffi_api.VMShapeLower()


-def to_anf(mod: IRModule):
-    return _ffi_api.to_anf(mod)
+def ToANF() -> tvm.transform.Pass:
+    """Transform Relax IR to A-normal form.
+    Returns
+    -------
+    ret: tvm.transform.Pass
+    """
+    return _ffi_api.ToANF()
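A short usage sketch of the new wrappers (hedged: `mod` stands for a relax IRModule that contains a dataflow block); each pass runs under the ambient PassContext:

    import tvm
    from tvm import relax

    with tvm.transform.PassContext(opt_level=3):
        mod = relax.transform.FMARewrite()(mod)     # rewrite multiply-add chains in dataflow blocks
        mod = relax.transform.ToNonDataflow()(mod)  # then drop the dataflow-block structure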
14 changes: 8 additions & 6 deletions python/tvm/relax/vm.py
@@ -168,10 +168,12 @@ def build(mod: tvm.IRModule,
lib: tvm.runtime.Module
A runtime module that contains generated code.
"""
-    new_mod = transform.to_non_dataflow(mod)
-    new_mod = transform.call_dps_rewrite(new_mod)
-    new_mod = transform.vm_memory_lower(new_mod)
-    new_mod = transform.vm_shape_lower(new_mod)
+    passes = [relax.transform.ToNonDataflow()]
+    passes.append(relax.transform.CallDPSRewrite())
+    passes.append(relax.transform.VMMemoryLower())
+    passes.append(relax.transform.VMShapeLower())
+    seq = tvm.transform.Sequential(passes)
+    new_mod = seq(mod)

# split primfunc and relax function
rx_mod, tir_mod = _split_tir_relax(new_mod)
@@ -189,5 +191,5 @@ def _split_tir_relax(mod: tvm.IRModule) -> Tuple[tvm.IRModule, tvm.IRModule]:
elif isinstance(mod[gv], relax.Function):
rx_mod[gv] = mod[gv]
else:
raise ValueError("An IRModule should contain contain relax function and TIR primfunc.")
return rx_mod, tir_mod
raise ValueError("An IRModule should contain relax function and/or TIR primfunc.")
return rx_mod, tir_mod
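The build() refactor above trades four hand-chained calls for a declarative pipeline. An equivalent standalone sketch (with `mod` assumed to be the input IRModule):

    import tvm
    from tvm import relax

    seq = tvm.transform.Sequential(
        [
            relax.transform.ToNonDataflow(),
            relax.transform.CallDPSRewrite(),
            relax.transform.VMMemoryLower(),
            relax.transform.VMShapeLower(),
        ]
    )
    new_mod = seq(mod)

One payoff of the Sequential form: passes can be inserted, reordered, or skipped (for example via PassContext's disabled_pass list) without editing build() itself.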
50 changes: 21 additions & 29 deletions src/relax/backend/vm/vm_memory_lower.cc
@@ -18,9 +18,10 @@
*/
/*!
* \file src/relax/backend/vm/vm_memory_lower.cc
- * \brief
+ * \brief Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics.
*/
#include <tvm/relax/attrs/memory.h>
+#include <tvm/relax/backend.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/type.h>
#include <tvm/tir/op.h>
@@ -29,33 +30,18 @@

namespace tvm {
namespace relax {
-namespace vm {

// ==================
// MemLowerMutator
// Lower the relax.builtin.alloc_tensor op to VM builtin functions.
// Example:
// x = relax.builtin.alloc_tensor((m, n))
// -->
-// gv0 = relax.call_packed("relax.vm.builtin.alloc_storage", (m * n), relax.attrs.AllocStorageAttrs)
-// gv1 = relax.call_packed("relax.vm.builtin.alloc_tensor", gv0, (m, n), relax.attrs.AllocTensorAttrs)
+// gv0 = relax.call_packed("relax.vm.builtin.alloc_storage", (m * n), relax.attrs.AllocStorageAttrs)
+// gv1 = relax.call_packed("relax.vm.builtin.alloc_tensor", gv0, (m, n),
+// relax.attrs.AllocTensorAttrs)

class VMMemLowerMutator : public ExprMutator {
public:
-  explicit VMMemLowerMutator(IRModule mod) { mod_ = mod; }
-
-  IRModule Lower() {
-    IRModule ret_mod = IRModule();
-    for (auto& p : mod_->functions) {
-      Expr func = p.second;
-      if (p.second->IsInstance<FunctionNode>()) {
-        func = this->VisitExpr(p.second);
-      }
-      ret_mod->Add(p.first, Downcast<BaseFunc>(func));
-    }
-    return ret_mod;
-  }

Expr ComputeStorageSize(const Expr& shape, const Type& type) const {
DynTensorType tensor_type = Downcast<DynTensorType>(type);
DataType dtype = DataType(tensor_type->dtype);
@@ -101,27 +87,33 @@ class VMMemLowerMutator : public ExprMutator {
storage_attr->dtype = DataType::Float(32);
storage_attr->device_type = 1;

-    Var storage = builder_->Emit(Call(vm_alloc_storage_op, {storage_size}, Attrs(storage_attr)), "storage");
+    Var storage =
+        builder_->Emit(Call(vm_alloc_storage_op, {storage_size}, Attrs(storage_attr)), "storage");
auto tensor_attr = make_object<AllocTensorAttrs>();
tensor_attr->offset = 0;
tensor_attr->dtype = DataType::Float(32);
Expr shape = call->args[0];
-    Var tensor = builder_->Emit(Call(vm_alloc_tensor_op, {storage, shape}, Attrs(tensor_attr)), "tensor");
+    Var tensor =
+        builder_->Emit(Call(vm_alloc_tensor_op, {storage, shape}, Attrs(tensor_attr)), "tensor");
return tensor;
}

return GetRef<Expr>(call);
}

- private:
-  IRModule mod_;
};

TVM_REGISTER_GLOBAL("relax.transform.vm_memory_lower")
.set_body_typed([](IRModule mod) {
return VMMemLowerMutator(mod).Lower();
});
Expr VMMemLower(const Expr& e) { return VMMemLowerMutator().VisitExpr(e); }

namespace transform {

Pass VMMemoryLower() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(VMMemLower(f)); };
return CreateFunctionPass(pass_func, 0, "VMMemoryLower", {});
}

TVM_REGISTER_GLOBAL("relax.transform.VMMemoryLower").set_body_typed(VMMemoryLower);

-}  // namespace vm
+}  // namespace transform
} // namespace relax
} // namespace tvm
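The rewritten registration exposes the pass factory under the global name "relax.transform.VMMemoryLower", which is what the Python wrapper in transform.py resolves through _ffi_api. A sketch of the lookup the wrapper performs under the hood (`mod` is assumed to be a relax IRModule):

    import tvm

    # The TVM_REGISTER_GLOBAL name maps to a packed function retrievable at
    # runtime; relax.transform.VMMemoryLower() is a thin wrapper around it.
    factory = tvm.get_global_func("relax.transform.VMMemoryLower")
    mem_lower = factory()     # a tvm.transform.Pass
    new_mod = mem_lower(mod)  # rewrites alloc_tensor into the VM builtins shown above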
(Diffs for the remaining 7 changed files are not shown.)