[Relax] Eager free original weights in transform_params (#16674)
* [Relax] Eager free original weights in transform_params

* address comments
vinx13 authored Mar 6, 2024
1 parent 22dd8d8 commit a0f57a0
Showing 3 changed files with 147 additions and 4 deletions.
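In short, this change lets the lifted `*transform_params` function release each original weight as soon as it has been consumed, instead of keeping the whole bundled input tuple alive until the function returns. The behavior is off by default and is enabled through a new pass config key; a minimal usage sketch (`Before` stands for an input module such as the one in the test below):

import tvm
from tvm import relax

# Opt in to eagerly freeing the original weights (defaults to False).
with tvm.transform.PassContext(
    config={"relax.lift_transform_params.consume_params": True}
):
    after = relax.transform.LiftTransformParams()(Before)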
48 changes: 48 additions & 0 deletions src/relax/transform/lift_transform_params.cc
@@ -38,6 +38,9 @@
namespace tvm {
namespace relax {

constexpr const char* kLiftTransformConsumeParams = "relax.lift_transform_params.consume_params";
TVM_REGISTER_PASS_CONFIG_OPTION(kLiftTransformConsumeParams, Bool);

namespace {

struct CollectInfo {
@@ -449,6 +452,48 @@ inline bool ends_with(const std::string& value, const std::string& ending) {
         std::equal(ending.rbegin(), ending.rend(), value.rbegin());
}

/*!
 * \brief A mutator that rewrites `transform_params` functions to release each original weight
 * after use. This is done by calling builtin.tuple_reset_item to reset the corresponding slot of
 * the bundled weight tuple. It requires `BundleModelParams` to be called before this mutator.
 */
class ConsumeBundledParams : public ExprMutator {
 public:
  void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* tuple_get_item) final {
    static const auto& call_pure_packed = Op::Get("relax.call_pure_packed");
    static const auto& builtin_tuple_reset_item = ExternFunc("vm.builtin.tuple_reset_item");
    if (tuple_get_item->tuple.same_as(params_)) {
      // If this parameter index was already extracted, reuse the earlier binding.
      if (auto it = param_remap_.find(tuple_get_item->index); it != param_remap_.end()) {
        ReEmitBinding(binding, it->second);
        return;
      }
      ExprMutator::VisitBinding_(binding, tuple_get_item);
      auto new_var = VisitExpr(binding->var);
      param_remap_[tuple_get_item->index] = new_var;
      // Reset the tuple slot right after the first access, so the VM drops its
      // reference to the original weight as early as possible.
      builder_->Emit(
          Call(call_pure_packed,
               {builtin_tuple_reset_item, tuple_get_item->tuple, PrimValue(tuple_get_item->index)},
               tvm::Attrs(), {TupleStructInfo(Array<StructInfo>{})}));
    } else {
      ExprMutator::VisitBinding_(binding, tuple_get_item);
    }
  }

  Expr VisitExpr_(const FunctionNode* func) final {
    auto opt_num_input = func->GetAttr<Integer>(attr::kNumInput);
    ICHECK(opt_num_input.defined());
    auto num_input = opt_num_input.value()->value;
    // After BundleModelParams, the weights form a single tuple that is the
    // last parameter of the function.
    ICHECK_EQ(func->params.size(), num_input + 1);
    params_ = func->params.back();
    ICHECK(params_->struct_info_.as<TupleStructInfoNode>());
    return ExprMutator::VisitExpr_(func);
  }

 private:
  Var params_;
  std::unordered_map<int, Expr> param_remap_;
};

} // namespace

namespace transform {
Expand Down Expand Up @@ -498,6 +543,9 @@ Pass LiftTransformParams() {
if (ends_with(func_name, "transform_params")) {
func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
func = BundleModelParams(func);
if (pc->GetConfig<Bool>(kLiftTransformConsumeParams).value_or(Bool(false))) {
func = Downcast<Function>(ConsumeBundledParams()(func));
}
to_add[gvar] = func;
} else if (ends_with(func_name, "_runtime")) {
std::string name(func_name.begin(), func_name.end() - sizeof("_runtime") + 1);
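Taken together, `ConsumeBundledParams` turns the first read of each bundled weight into a read-then-reset pair. A rough Python analogy of the emitted pattern (a sketch only, not the generated Relax IR; `params` is a Python list standing in for the bundled weight tuple — compare `main_transform_params` in `ExpectedConsumeParams` below):

def main_transform_params(params):
    # Read the weight, then clear its slot so the tuple no longer keeps
    # the original alive; this mirrors vm.builtin.tuple_reset_item.
    lv1 = params[0]
    params[0] = None
    lv2 = transform_layout_IOHW_to_OIHW(lv1)  # original freed once lv1 is dead
    lv = params[1]
    params[1] = None
    return (lv, lv2)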
5 changes: 5 additions & 0 deletions src/runtime/relax_vm/builtin.cc
@@ -499,6 +499,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.invoke_debug_func")
TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem")
.set_body_typed([](runtime::Array<ObjectRef> arr, int64_t index) { return arr[index]; });

TVM_REGISTER_GLOBAL("vm.builtin.tuple_reset_item")
.set_body_typed([](runtime::Array<ObjectRef> arr, int64_t index) {
arr.Set(index, ObjectRef(nullptr));
});

TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Array<ObjectRef> arr;
for (int i = 0; i < args.num_args; ++i) {
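The new builtin is an ordinary registered global function, so it can be looked up from Python like its neighbors; a small sketch (`vm.builtin.make_tuple`, shown above, builds the tuple; note that `Array` has copy-on-write semantics, so whether a Python caller observes the cleared slot depends on reference counts):

import tvm

# Both helpers are registered via TVM_REGISTER_GLOBAL above.
make_tuple = tvm.get_global_func("vm.builtin.make_tuple")
reset_item = tvm.get_global_func("vm.builtin.tuple_reset_item")

t = make_tuple(tvm.nd.array([1.0]), tvm.nd.array([2.0]))
reset_item(t, 0)  # clears slot 0 so the stored object can be released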
98 changes: 94 additions & 4 deletions tests/python/relax/test_transform_lift_transform_params.py
@@ -26,7 +26,8 @@
import tvm.topi.testing


-def test_basic():
+@pytest.mark.parametrize("consume_params", [True, False])
+def test_basic(consume_params):
    @tvm.script.ir_module
    class Before:
        @T.prim_func
@@ -132,12 +133,101 @@ def main_transform_params(
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class ExpectedConsumeParams:
        @R.function
        def main(
            x: R.Tensor((1, 3, 224, 224), dtype="float32"),
            w2: R.Tensor((16, 16, 3, 3), dtype="float32"),
            w1_transformed: R.Tensor((16, 3, 3, 3), dtype="float32"),
        ) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d(
                    x,
                    w1_transformed,
                    strides=[1, 1],
                    padding=[1, 1, 1, 1],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="void",
                )
                conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d(
                    conv1,
                    w2,
                    strides=[1, 1],
                    padding=[1, 1, 1, 1],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="void",
                )
                R.output(conv2)
            return conv2

        @T.prim_func
        def transform_layout_IOHW_to_OIHW(
            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
        ):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]

        @R.function
        def main_transform_params(
            params: R.Tuple(
                R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
            )
        ) -> R.Tuple(
            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
        ):
            R.func_attr({"num_input": 0})
            cls = ExpectedConsumeParams
            with R.dataflow():
                lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
                _1: R.Tuple = R.call_pure_packed(
                    "vm.builtin.tuple_reset_item",
                    params,
                    R.prim_value(T.int32(0)),
                    sinfo_args=(R.Tuple,),
                )
                lv2 = R.call_tir(
                    cls.transform_layout_IOHW_to_OIHW,
                    (lv1,),
                    out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
                )
                lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
                _2: R.Tuple = R.call_pure_packed(
                    "vm.builtin.tuple_reset_item",
                    params,
                    R.prim_value(T.int32(1)),
                    sinfo_args=(R.Tuple,),
                )
                gv: R.Tuple(
                    R.Tensor((16, 16, 3, 3), dtype="float32"),
                    R.Tensor((16, 3, 3, 3), dtype="float32"),
                ) = (lv, lv2)
                R.output(gv)
            return gv

     mod = Before
-    after = relax.transform.LiftTransformParams()(mod)
-    tvm.ir.assert_structural_equal(after, Expected)
+    expected = Expected if not consume_params else ExpectedConsumeParams
+    with tvm.transform.PassContext(
+        config={"relax.lift_transform_params.consume_params": consume_params}
+    ):
+        after = relax.transform.LiftTransformParams()(mod)
+    tvm.ir.assert_structural_equal(after, expected)

     names_after = [param.name_hint for param in after["main"].params]
-    names_expected = [param.name_hint for param in Expected["main"].params]
+    names_expected = [param.name_hint for param in expected["main"].params]
     assert names_after == names_expected
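Both settings are covered by the parametrization; for instance, to run just this test (assuming pytest and a TVM build are available):

import pytest

# Executes test_basic[True] and test_basic[False].
pytest.main(["tests/python/relax/test_transform_lift_transform_params.py", "-k", "test_basic"])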


