Commit

[Relax] Enable capturing symbolic shapes in cuda graph (apache#16815)
* [Relax] Enable capturing symbolic shapes in cuda graph

* Add Bind sinfo util

* Bind ret sinfo

* address comments

* add comments

* fix

* update test
vinx13 authored Mar 30, 2024
1 parent eb4175b commit ef32a61
Showing 5 changed files with 321 additions and 31 deletions.
7 changes: 7 additions & 0 deletions include/tvm/relax/utils.h
@@ -50,6 +50,13 @@ namespace relax {
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map = {});

/*!
 * \brief Bind the symbolic variables in a StructInfo. This is a helper usually called by
 * other passes during optimization.
*/
TVM_DLL StructInfo Bind(const StructInfo& sinfo,
const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map);

/*!
* \brief Infer a binding map for symbolic variables
*
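A usage sketch of the new overload (illustrative only; the variable names are assumptions, the APIs are existing TVM constructs): binding the symbolic dimension n to a constant turns R.Tensor((n, 16), "float32") into R.Tensor((8, 16), "float32").

    // Minimal sketch, assuming a tensor struct info with one symbolic dimension n.
    #include <tvm/relax/struct_info.h>
    #include <tvm/relax/utils.h>
    #include <tvm/tir/var.h>

    using namespace tvm;
    using namespace tvm::relax;

    StructInfo BindExample() {
      tir::Var n("n", DataType::Int(64));
      StructInfo sinfo = TensorStructInfo(ShapeExpr({n, 16}), DataType::Float(32));
      Map<tir::Var, PrimExpr> remap;
      remap.Set(n, IntImm(DataType::Int(64), 8));
      return Bind(sinfo, remap);  // R.Tensor((8, 16), "float32")
    }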
161 changes: 139 additions & 22 deletions src/relax/transform/rewrite_cuda_graph.cc
@@ -53,6 +53,8 @@
#include <tvm/relax/backend.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/stmt_functor.h>

#include "../../support/arena.h"
#include "../../support/ordered_set.h"
@@ -82,6 +84,8 @@ struct LiftedFunctionRewritePlan {
std::vector<const VarNode*> outputs;
// The corresponding binding vars in the original function of the inputs of the lifted function
std::vector<const VarNode*> inputs;
// The tir vars in the original function that are propagated to the lifted function
Optional<ShapeExpr> propogated_tir_vars = NullOpt;
};

/*! \brief Builder of the lifted function for cuda graph capturing or allocations */
@@ -98,6 +102,11 @@ class FuncBuilder : public ExprMutator {
* \param var The variable to mark as input
*/
void MarkInput(const VarNode* var) { inputs_.push_back(var); }
/*!
* \brief Mark a TIR variable as the ShapeExpr input of the new function.
* \param var The variable to mark as input
*/
void MarkShapeExprInput(const tir::VarNode* var) { shape_expr_inputs_.push_back(var); }
/*!
* \brief Mark a variable as the output of the new function. The variable must be the LHS of an
* existing binding in the new function.
@@ -111,12 +120,27 @@
/*! \brief Build the new function */
Function Build() {
Array<Var> params;
Optional<Var> shape_expr = NullOpt;
if (shape_expr_inputs_.size()) {
Array<PrimExpr> tir_vars;
for (const auto* var : shape_expr_inputs_) {
auto new_var = GetRef<tir::Var>(var).copy_with_suffix("");
tir_var_remap_.Set(GetRef<tir::Var>(var), new_var);
tir_vars.push_back(new_var);
}
shape_expr = Var("shape_expr", ShapeStructInfo(tir_vars));
}
// Set up the parameters
for (const auto* input : inputs_) {
- auto new_var = Var(input->name_hint(), Downcast<Optional<StructInfo>>(input->struct_info_));
+ auto new_var = Var(
+ input->name_hint(),
+ VisitExprDepStructInfoField(Downcast<Optional<StructInfo>>(input->struct_info_).value()));
var_remap_[input->vid] = new_var;
params.push_back(new_var);
}
if (shape_expr) {
params.push_back(shape_expr.value());
}
// Emit the function body
builder_->BeginBindingBlock();
for (const auto* binding : bindings_) {
@@ -137,9 +161,13 @@ class FuncBuilder : public ExprMutator {
return func;
}

PrimExpr VisitPrimExpr(const PrimExpr& expr) { return tir::Substitute(expr, tir_var_remap_); }

support::OrderedSet<const VarNode*> inputs_;
support::OrderedSet<const VarNode*> outputs_;
support::OrderedSet<const tir::VarNode*> shape_expr_inputs_;
std::vector<const VarBindingNode*> bindings_;
Map<tir::Var, PrimExpr> tir_var_remap_;
};
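The remap-and-substitute pattern used by FuncBuilder above can be exercised in isolation; a minimal sketch (the variable names are assumptions; copy_with_suffix and tir::Substitute are the same TVM APIs called in Build() and VisitPrimExpr):

    // Minimal sketch: mint a fresh TIR var for the lifted function's parameter list
    // and rewrite an expression to refer to it.
    #include <tvm/tir/op.h>
    #include <tvm/tir/stmt_functor.h>
    #include <tvm/tir/var.h>

    using namespace tvm;

    PrimExpr RemapExample() {
      tir::Var n("n", DataType::Int(64));
      tir::Var n_new = n.copy_with_suffix("");  // fresh var, same name hint
      Map<tir::Var, PrimExpr> remap;
      remap.Set(n, n_new);
      return tir::Substitute(n * 2, remap);  // the result references n_new
    }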

/*!
@@ -159,6 +187,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
static_vars_.insert(func->params[i].get());
}
}
CollectSymbolicVarHints(func);
VisitExpr(func);
}
}
@@ -174,6 +203,13 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
for (const auto* binding : region->bindings_) {
plan.lifted_bindings.insert(binding->var.get());
}
if (region->shape_expr_inputs_.size()) {
Array<PrimExpr> tir_vars;
for (const auto* var : region->shape_expr_inputs_) {
tir_vars.push_back(GetRef<PrimExpr>(var));
}
plan.propogated_tir_vars = ShapeExpr(tir_vars);
}
plan.inputs.assign(region->inputs_.begin(), region->inputs_.end());
plan.outputs.assign(region->outputs_.begin(), region->outputs_.end());
return plan;
@@ -189,6 +225,18 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
return plans;
}

/*!
* \brief Collect the name hints of the symbolic variables that are allowed to be captured.
*/
void CollectSymbolicVarHints(const Function& func) {
capture_symbolic_vars_.clear();
if (auto symbolic_vars =
func->attrs.GetAttr<Array<String>>("relax.rewrite_cuda_graph.capture_symbolic_vars")) {
for (const auto& var : symbolic_vars.value()) {
capture_symbolic_vars_.insert(var);
}
}
}
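Functions opt into capture through this attribute. A minimal sketch of attaching it from C++ (the function value and the variable name "n" are assumptions; WithAttr is the standard IR utility from tvm/ir/function.h):

    // Minimal sketch: allow the symbolic variable "n" to be captured by the pass.
    #include <tvm/ir/function.h>
    #include <tvm/relax/expr.h>

    tvm::relax::Function MarkCapturable(tvm::relax::Function func) {
      return tvm::WithAttr(std::move(func), "relax.rewrite_cuda_graph.capture_symbolic_vars",
                           tvm::Array<tvm::String>{"n"});
    }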
/*!
 * \brief Start a new static region. This method should be called when encountering a
* CUDA kernel launch (calls to PrimFunc or ExternFunc) that only depends on static parameters.
@@ -239,8 +287,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
// Check whether the call can be lifted to the capture function. It requires all the arguments
// to be static and the call to be a kernel launch or a pure operation (e.g. memory view).
std::vector<const VarNode*> args;
std::vector<const tir::VarNode*> tir_vars;
bool is_all_static = [&]() {
- if (!IsStatic(call->args, &args)) {
+ if (!IsStatic(call->args, &args, &tir_vars)) {
return false;
}
if (call_gv != nullptr && !call_prim_func) {
Expand Down Expand Up @@ -276,15 +325,16 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
StartRegion();
}
AddStaticBinding(binding, /*is_alloc_storage=*/false);
- MarkAsFuncInput(args);
+ MarkAsFuncInput(args, tir_vars);
} else {
EndRegion();
}

MarkAsFuncOutput(args);
}

- void MarkAsFuncInput(const std::vector<const VarNode*>& vars) {
+ void MarkAsFuncInput(const std::vector<const VarNode*>& vars,
+ const std::vector<const tir::VarNode*>& tir_vars = {}) {
if (current_.capture_builder == nullptr) {
return;
}
@@ -294,6 +344,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
current_.capture_builder->MarkInput(var);
}
}
for (const tir::VarNode* tir_var : tir_vars) {
current_.capture_builder->MarkShapeExprInput(tir_var);
}
}

void MarkAsFuncOutput(const std::vector<const VarNode*>& vars) {
@@ -321,9 +374,10 @@ class CUDAGraphRewritePlanner : public ExprVisitor {

void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) final {
std::vector<const VarNode*> args;
- if (IsStatic(tuple->fields, &args)) {
+ std::vector<const tir::VarNode*> tir_vars;
+ if (IsStatic(tuple->fields, &args, &tir_vars)) {
AddStaticBinding(binding, false);
- MarkAsFuncInput(args);
+ MarkAsFuncInput(args, tir_vars);
} else {
EndRegion();
}
@@ -343,48 +397,83 @@
}

bool IsStatic(const PrimExpr& expr,
- [[maybe_unused]] std::vector<const VarNode*>* vars_collector = nullptr) {
- return expr->IsInstance<tir::IntImmNode>() || expr->IsInstance<tir::FloatImmNode>();
+ [[maybe_unused]] std::vector<const VarNode*>* vars_collector = nullptr,
+ std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
+ bool is_static = true;
+ tir::PostOrderVisit(expr, [&](const ObjectRef& e) {
+ if (auto var = e.as<tir::VarNode>()) {
+ if (!capture_symbolic_vars_.count(var->name_hint)) {
+ is_static = false;
+ return;
+ }
+ if (tir_vars_collector != nullptr) {
+ tir_vars_collector->push_back(var);
+ }
+ }
+ });
+ return is_static;
}
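The traversal above relies on tir::PostOrderVisit from tvm/tir/stmt_functor.h. The same var-collection idiom as a standalone sketch (names are assumptions):

    // Minimal sketch: walk a PrimExpr and collect every tir::Var it references.
    #include <tvm/tir/expr.h>
    #include <tvm/tir/stmt_functor.h>

    #include <vector>

    std::vector<const tvm::tir::VarNode*> CollectTirVars(const tvm::PrimExpr& expr) {
      std::vector<const tvm::tir::VarNode*> vars;
      tvm::tir::PostOrderVisit(expr, [&](const tvm::ObjectRef& node) {
        if (const auto* var = node.as<tvm::tir::VarNode>()) {
          vars.push_back(var);
        }
      });
      return vars;
    }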

- bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector = nullptr) {
+ bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector = nullptr,
+ std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
if (expr->IsInstance<ConstantNode>() || expr->IsInstance<DataTypeImmNode>() ||
- expr->IsInstance<StringImmNode>()) {
+ expr->IsInstance<StringImmNode>() || expr->IsInstance<GlobalVarNode>()) {
return true;
}
if (const auto* prim_value = expr.as<PrimValueNode>()) {
- return IsStatic(prim_value->value, vars_collector);
+ return IsStatic(prim_value->value, vars_collector, tir_vars_collector);
}
if (const auto* var = expr.as<VarNode>()) {
if (vars_collector != nullptr) {
vars_collector->push_back(var);
}
- return static_vars_.count(var);
+ // recursively check the struct info to collect the symbolic TIR vars
+ return static_vars_.count(var) && IsStatic(Downcast<StructInfo>(var->struct_info_.value()),
+ vars_collector, tir_vars_collector);
}

if (const auto* shape = expr.as<ShapeExprNode>()) {
- return IsStatic(shape->values, vars_collector);
+ return IsStatic(shape->values, vars_collector, tir_vars_collector);
}
if (const auto* tuple = expr.as<TupleNode>()) {
- return IsStatic(tuple->fields, vars_collector);
+ return IsStatic(tuple->fields, vars_collector, tir_vars_collector);
}
return false;
}

template <typename T>
- bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>* vars_collector = nullptr) {
+ bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>* vars_collector = nullptr,
+ std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
bool result = true;
for (const auto& expr : exprs) {
// If a collector is provided, we will collect all the vars in the exprs and should not
// perform short-circuiting.
- result &= IsStatic(expr, vars_collector);
- if (!vars_collector && !result) {
+ result &= IsStatic(expr, vars_collector, tir_vars_collector);
+ if (vars_collector == nullptr && tir_vars_collector == nullptr && !result) {
return false;
}
}
return result;
}

bool IsStatic(const StructInfo& sinfo, std::vector<const VarNode*>* vars_collector = nullptr,
std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
if (const auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
if (auto shape = tensor_sinfo->GetShape()) {
return IsStatic(shape.value(), vars_collector, tir_vars_collector);
}
} else if (const auto* shape_sinfo = sinfo.as<ShapeStructInfoNode>()) {
if (shape_sinfo->values) {
return IsStatic(shape_sinfo->values.value(), vars_collector, tir_vars_collector);
}
} else if (const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
return IsStatic(tuple_sinfo->fields, vars_collector, tir_vars_collector);
} else if (sinfo.as<ObjectStructInfoNode>() || sinfo.as<PrimStructInfoNode>()) {
return true;
}
return false;
}

private:
bool IsStaticAllocStorage(const VarBindingNode* binding) {
// Check if the allocation has constant shape
@@ -431,6 +520,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
Scope current_;
// Variables whose buffer address is fixed
std::unordered_set<const VarNode*> static_vars_;
// The name of the variables that are allowed to be symbolic
std::unordered_set<String> capture_symbolic_vars_;
// Binding to the FuncBuilder if the binding is lifted. This is used to update the inputs/outputs
// of the lifted function when its binding is used outside.
std::unordered_map<const VarNode*, FuncBuilder*> binding_to_region_;
@@ -475,22 +566,48 @@ class CUDAGraphRewriter : public ExprMutator {
auto gv_func =
builder_->AddFunction(plan.func, plan.is_alloc ? "cuda_graph_alloc" : "cuda_graph_capture");
if (plan.is_alloc) {
// Storage allocation should be fully static and shouldn't depend on any symbolic variables.
ICHECK(!plan.propogated_tir_vars.defined());
ICHECK(plan.inputs.empty());
launch_subgraph =
Call(call_builtin_with_ctx_op,
{builtin_get_cached_alloc,
Tuple({gv_func, PrimValue(IntImm(DataType::Int(64), index_alloc_++))})},
Attrs(), {plan.func->ret_struct_info});
} else {
+ StructInfo call_sinfo = plan.func->ret_struct_info;
// Arguments of the lifted function
Array<Expr> args;
for (const auto& arg : plan.inputs) {
args.push_back(VisitExpr_(arg));
}
- launch_subgraph = Call(
- call_builtin_with_ctx_op,
- {builtin_run_or_capture,
- Tuple({gv_func, Tuple(args), PrimValue(IntImm(DataType::Int(64), index_capture_++))})},
- Attrs(), {plan.func->ret_struct_info});
+ if (plan.propogated_tir_vars.defined()) {
+ ShapeExpr propogated_tir_vars = plan.propogated_tir_vars.value();
+ args.push_back(propogated_tir_vars);
+ // The ret_struct_info of the lifted function can contain symbolic variables. We need to
+ // bind the symbolic parameters to the actual values.
+ const auto& shape_expr = plan.func->params.back();
+ auto symbolic_params =
+ Downcast<ShapeStructInfo>(shape_expr->struct_info_.value())->values.value();
+ Map<tir::Var, PrimExpr> tir_var_remap;
+ ICHECK_EQ(symbolic_params.size(), propogated_tir_vars->values.size());
+ for (int i = 0; i < static_cast<int>(symbolic_params.size()); ++i) {
+ tir_var_remap.Set(Downcast<tir::Var>(symbolic_params[i]), propogated_tir_vars->values[i]);
+ }
+ call_sinfo = Bind(call_sinfo, tir_var_remap);
+ }
+ // Arguments of builtin_run_or_capture
+ Array<Expr> tuple_arg_fields{gv_func, Tuple(args),
+ PrimValue(IntImm(DataType::Int(64), index_capture_++))};
+ if (plan.propogated_tir_vars.defined()) {
+ // The shape expr is passed twice: once as the last argument of the lifted function, and
+ // once as the last argument of builtin_run_or_capture, where it serves as the cache key.
+ // Passing it explicitly in both places simplifies handling during the capture phase.
+ tuple_arg_fields.push_back(plan.propogated_tir_vars.value());
+ }
+ launch_subgraph =
+ Call(call_builtin_with_ctx_op, {builtin_run_or_capture, Tuple(tuple_arg_fields)}, Attrs(),
+ {call_sinfo});
}
Expr ret_value = builder_->Emit(launch_subgraph);
for (int i = 0; i < static_cast<int>(plan.outputs.size()); ++i) {
Expand Down
4 changes: 4 additions & 0 deletions src/relax/utils.cc
@@ -144,6 +144,10 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
return ExprBinder(binds, symbolic_var_map).VisitExpr(expr);
}

StructInfo Bind(const StructInfo& sinfo, const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map) {
return ExprBinder({}, symbolic_var_map).VisitExprDepStructInfoField(sinfo);
}

tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
const tvm::Map<relax::Var, relax::Expr>& relax_var_remap, arith::Analyzer* analyzer) {
tvm::Map<tir::Var, PrimExpr> tir_var_remap;
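End to end, the rewrite runs as an ordinary module pass. A minimal sketch (assuming RewriteCUDAGraph is exposed via tvm/relax/transform.h, and that `mod` is a Relax IRModule whose functions carry the capture attribute shown earlier):

    // Minimal sketch: apply the pass to a module.
    #include <tvm/ir/module.h>
    #include <tvm/relax/transform.h>

    tvm::IRModule ApplyCudaGraphRewrite(tvm::IRModule mod) {
      tvm::transform::Pass pass = tvm::relax::transform::RewriteCUDAGraph();
      return pass(mod);
    }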
(The diffs for the remaining 2 changed files are not shown.)
