[Relax] Capture symbolic vars in struct info of weights (apache#16834)
vinx13 authored Apr 3, 2024
1 parent 9862c84 commit 54e31f3
Showing 2 changed files with 121 additions and 15 deletions.
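
For context, a minimal TVMScript sketch of how a function opts into the behavior this commit extends: `num_input` marks the trailing parameters as fixed weights whose symbolic shape variables become capturable, and the `relax.rewrite_cuda_graph.capture_symbolic_vars` hint lists variable names that may be captured even when they appear on runtime inputs. The module below is illustrative only (its body mirrors the test case added in this commit), and `relax.transform.RewriteCUDAGraph()` is assumed to be the public entry point for this pass.

# Illustrative sketch only: "num_input": 1 marks `w` as a fixed weight, so its
# symbolic dimension `m` becomes capturable; the capture_symbolic_vars hint
# additionally allows variables with the listed names to be captured when they
# appear on runtime inputs.
from tvm import relax
from tvm.script import ir as I, relax as R, tir as T


@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((8,), "float16"), w: R.Tensor(("m",), "float16")):
        m = T.int64()
        R.func_attr(
            {
                "relax.force_pure": True,
                "num_input": 1,
                "relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"],
            }
        )
        storage = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16")
        alloc = R.memory.alloc_tensor(storage, 0, R.shape([8]), "float16")
        _ = R.call_packed("dummy", x, w, alloc, sinfo_args=(R.Tuple,))
        gv = (alloc,)
        return gv


mod = relax.transform.RewriteCUDAGraph()(Module)  # assumed public entry point
mod.show()
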
48 changes: 33 additions & 15 deletions src/relax/transform/rewrite_cuda_graph.cc
@@ -239,13 +239,31 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
if (pair.second->IsInstance<FunctionNode>()) {
// If a function has the num_input attribute, the last func->params.size() - num_inputs
// inputs are assumed to be fixed and thus they can be captured into a cuda graph.
// The symbolic variables in the struct info of the fixed inputs (weights) are also allowed
// to be captured.
// If hints for capturing symbolic variables are provided via the
// 'relax.rewrite_cuda_graph.capture_symbolic_vars' annotation, the variables with these
// names are extracted from the struct info of the inputs and captured as well.
const auto& func = Downcast<Function>(pair.second);
if (auto num_input = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
for (size_t i = num_input.value().IntValue(); i < func->params.size(); ++i) {
auto num_inputs =
func->attrs.GetAttr<Integer>(attr::kNumInput).value_or(Integer(func->params.size()));
auto capture_symbolic_var_name_hints = ExtractSymbolicVarHints(func);
for (int i = 0; i < static_cast<int>(func->params.size()); ++i) {
Array<tir::Var> symbolic_vars = DefinableTIRVarsInStructInfo(
Downcast<StructInfo>(func->params[i]->struct_info_.value()));
if (i < num_inputs.IntValue()) {
for (const auto& symbolic_var : symbolic_vars) {
if (capture_symbolic_var_name_hints.count(symbolic_var->name_hint)) {
capture_symbolic_vars_.insert(symbolic_var.get());
}
}
} else {
static_vars_.insert(func->params[i].get());
for (const auto& symbolic_var : symbolic_vars) {
capture_symbolic_vars_.insert(symbolic_var.get());
}
}
}
CollectSymbolicVarHints(func);
disabled_storage_vars_ = OutputStorageCollector::Collect(func);
VisitExpr(func);
}
@@ -284,17 +302,16 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
}

/*!
* \brief Collect the name hints of the symbolic variables that are allowed to be captured.
* \brief Extract the name hints of the symbolic variables that are allowed to be captured
* from the function attributes.
*/
void CollectSymbolicVarHints(const Function& func) {
capture_symbolic_vars_.clear();
if (auto symbolic_vars =
func->attrs.GetAttr<Array<String>>("relax.rewrite_cuda_graph.capture_symbolic_vars")) {
for (const auto& var : symbolic_vars.value()) {
capture_symbolic_vars_.insert(var);
}
}
std::unordered_set<String> ExtractSymbolicVarHints(const Function& func) {
auto symbolic_var_names =
func->attrs.GetAttr<Array<String>>("relax.rewrite_cuda_graph.capture_symbolic_vars")
.value_or(Array<String>());
return {symbolic_var_names.begin(), symbolic_var_names.end()};
}

/*!
*\brief Start a new static region. This method should be called when encountering a
* CUDA kernel launch (calls to PrimFunc or ExternFunc) that only depends on static parameters.
@@ -467,7 +484,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
bool is_static = true;
tir::PostOrderVisit(expr, [&](const ObjectRef& e) {
if (auto var = e.as<tir::VarNode>()) {
if (!capture_symbolic_vars_.count(var->name_hint)) {
if (!capture_symbolic_vars_.count(var)) {
is_static = false;
return;
}
@@ -596,8 +613,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
FunctionScope current_function_scope_;
// Variables whose buffer address is fixed
std::unordered_set<const VarNode*> static_vars_;
// The name of the variables that are allowed to be symbolic
std::unordered_set<String> capture_symbolic_vars_;
// Symbolic variables that are allowed to be captured. This can come from symbolic shapes of
// weights or hints in the function annotations.
std::unordered_set<const tir::VarNode*> capture_symbolic_vars_;
// Binding to the FuncBuilder if the binding is lifted. This is used to update the inputs/outputs
// of the lifted function when its binding is used outside.
std::unordered_map<const VarNode*, FuncBuilder*> binding_to_region_;
88 changes: 88 additions & 0 deletions tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -1088,5 +1088,93 @@ def main(x: R.Tensor((8,), dtype="float32")) -> R.Tuple(R.Tensor((8,), dtype="fl
return gv


class TestStaticInputWithSymbolicShape(BaseCompare):
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor((8,), "float16"), w: R.Tensor(("m",))):
m = T.int64()
R.func_attr({"relax.force_pure": True, "num_input": 1})
storage1 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16")
alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([8]), "float16")
_ = R.call_packed("dummy", x, w, alloc1, sinfo_args=(R.Tuple,))
storage2 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16")
alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([8]), "float16")
_1 = R.call_packed("dummy", alloc1, w, alloc2, sinfo_args=(R.Tuple,))
storage3 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16")
alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([8]), "float16")
_2 = R.call_packed("dummy", alloc2, w, alloc3, sinfo_args=(R.Tuple,))
gv = (alloc3,)
return gv

@I.ir_module
class Expected:
@R.function(private=True)
def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
R.func_attr({"relax.force_pure": True})
storage1: R.Object = R.memory.alloc_storage(
R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float16")
)
storage2: R.Object = R.memory.alloc_storage(
R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float16")
)
gv: R.Tuple(R.Object, R.Object) = storage1, storage2
return gv

@R.function(private=True)
def main_cuda_graph_capture(
alloc1: R.Tensor((8,), dtype="float16"),
w: R.Tensor(("m",)),
alloc2: R.Tensor((8,), dtype="float16"),
shape_expr: R.Shape(["m"]),
) -> R.Tuple:
m = T.int64()
R.func_attr({"relax.force_pure": True})
R.call_packed("dummy", alloc1, w, alloc2, sinfo_args=(R.Tuple,))
R.tuple()
return R.tuple()

@R.function
def main(
x: R.Tensor((8,), dtype="float16"), w: R.Tensor(("m",))
) -> R.Tuple(R.Tensor((8,), dtype="float16")):
m = T.int64()
R.func_attr({"num_input": 1, "relax.force_pure": True})
cls = Expected
gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
"vm.builtin.cuda_graph.get_cached_alloc",
(cls.cuda_graph_alloc, R.prim_value(0)),
sinfo_args=(R.Tuple(R.Object, R.Object),),
)
storage1: R.Object = gv[0]
alloc1: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor(
storage1, R.prim_value(0), R.shape([8]), R.dtype("float16")
)
R.call_packed("dummy", x, w, alloc1, sinfo_args=(R.Tuple,))
storage2: R.Object = gv[1]
alloc2: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor(
storage2, R.prim_value(0), R.shape([8]), R.dtype("float16")
)
R.call_builtin_with_ctx(
"vm.builtin.cuda_graph.run_or_capture",
(
cls.main_cuda_graph_capture,
(alloc1, w, alloc2, R.shape([m])),
R.prim_value(0),
R.shape([m]),
),
sinfo_args=(R.Tuple,),
)
storage3: R.Object = R.memory.alloc_storage(
R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float16")
)
alloc3: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor(
storage3, R.prim_value(0), R.shape([8]), R.dtype("float16")
)
R.call_packed("dummy", alloc2, w, alloc3, sinfo_args=(R.Tuple,))
gv_1: R.Tuple(R.Tensor((8,), dtype="float16")) = (alloc3,)
return gv_1


if __name__ == "__main__":
tvm.testing.main()
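
As a rough stand-alone equivalent of what the BaseCompare harness does for the new test (a sketch only; the harness may perform additional normalization before comparing), one could run:

import tvm
from tvm import relax

# Apply the pass to the Before module and compare against Expected.
After = relax.transform.RewriteCUDAGraph()(TestStaticInputWithSymbolicShape.Before)
tvm.ir.assert_structural_equal(After, TestStaticInputWithSymbolicShape.Expected)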
