Skip to content

Commit 80bb96f

Browse files
author
Trevor Morris
committed
[BYOC] Prevent duplicate outputs in subgraph Tuple (apache#5320)
* Fix duplicate output in partitiongraph * Add test case * Fix test_annotated_regions with duplicate compiler_end outputs * Revert "Fix duplicate output in partitiongraph" This reverts commit e1f8ef3. * Prevent duplicate outputs in Tuple in PartitionGraph * Fix lint * Add another test case for when regions are merged, and when TupleGetItem was duplicated * Pull GetFunctionOutput out of branch, improve description of GetFunctionOutput * Use std::move for GetFunctionOutput. Fix typo with testcase name * Use tvm.transform.Sequential
1 parent 3116cbe commit 80bb96f

File tree

2 files changed

+260
-101
lines changed

2 files changed

+260
-101
lines changed

src/relay/transforms/partition_graph.cc

Lines changed: 125 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -205,99 +205,13 @@ class Partitioner : public ExprMutator {
205205
// region_function_calls is map that maintains
206206
// (each annotated regions) --> created function
207207

208-
if (region_function_calls.find(region) != region_function_calls.end()) {
209-
// This section is executed only if there are multiple outputs in the
210-
// region Thus, the function is always created and at the end there
211-
// would be a tuple node Therefore, we insert a tuple get item node.
212-
213-
// Use the already created tuple node
214-
auto sg_call = region_function_calls[region];
215-
int index = GetRetIdx(region, GetRef<Call>(call));
216-
CHECK_NE(index, -1);
217-
218-
auto tuple_get_item_ = TupleGetItem(sg_call, index);
219-
tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
220-
return std::move(tuple_get_item_);
221-
} else {
222-
// First time this region is encountered in the traversal
223-
// Creating the function
224-
225-
Array<Expr> fields;
226-
227-
for (auto ret : region->GetOutputs()) {
228-
auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
229-
fields.push_back(ret_expr);
230-
}
231-
int index = GetRetIdx(region, GetRef<Call>(call));
232-
CHECK_NE(index, -1);
233-
234-
Array<Var> params;
235-
Array<Expr> param_expr;
236-
std::unordered_map<std::string, runtime::NDArray> params_bind;
237-
238-
for (auto pair : region_args[region]) {
239-
params.push_back(pair.first);
240-
if (const auto* cn = pair.second.as<ConstantNode>()) {
241-
params_bind[pair.first->name_hint()] = cn->data;
242-
} else {
243-
param_expr.push_back(pair.second);
244-
}
245-
}
246-
247-
Function global_region_func;
248-
if (region->GetOutputs().size() == 1) {
249-
// If there are only a single output; no need to add a tuple
250-
global_region_func =
251-
Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
252-
} else {
253-
auto tuple = Tuple(fields);
254-
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
255-
}
256-
257-
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
258-
std::string name = target + "_" + std::to_string(region->GetID());
259-
260-
global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
261-
runtime::String(name));
262-
global_region_func =
263-
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
264-
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
265-
tvm::runtime::String(target));
266-
global_region_func =
267-
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
268-
269-
// Constant propagation
270-
if (!params_bind.empty()) {
271-
global_region_func = backend::BindParamsByName(global_region_func, params_bind);
272-
}
273-
274-
std::string fname = name;
275-
CHECK(!module_->ContainGlobalVar(fname))
276-
<< "Global function " << fname << " already exists";
277-
// Create a global function and add it to the IRModule for the region.
278-
// This way we lift the functions that should be handled by external
279-
// codegen to the module scope and rely on the pass manager to prevent
280-
// relay function level passes (i.e. simplify inference and fusion)
281-
// optimizing it.
282-
GlobalVar glob_func(fname);
283-
module_->Add(glob_func, global_region_func);
284-
285-
// The return type of callnode is the same as the type of the
286-
// compiler_end node.
287-
auto ret = Call(glob_func, param_expr);
288-
region_function_calls[region] = ret;
289-
290-
if (region->GetOutputs().size() == 1) {
291-
// If there is only a single output; no need to add a tuplegetitem
292-
// node
293-
return std::move(ret);
294-
} else {
295-
// Add a tuplegetitem node to select this output out of many
296-
auto tuple_get_item_ = TupleGetItem(ret, index);
297-
tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
298-
return std::move(tuple_get_item_);
299-
}
208+
if (region_function_calls.find(region) == region_function_calls.end()) {
209+
// First time this region is encountered in the traversal.
210+
// Creating the function.
211+
CreateFunction(region, call);
300212
}
213+
// Retrieve this particular output of function.
214+
return GetFunctionOutput(region, GetRef<Call>(call));
301215
}
302216
}
303217

@@ -456,18 +370,111 @@ class Partitioner : public ExprMutator {
456370
}
457371

458372
/*!
459-
* \brief Get the index of the return(output);
460-
* this is to be used as tuplegetitem idx
373+
* \brief This function is called first time that we encounter a compiler_end
374+
* node to create the function for the subgraph.
461375
*/
462-
int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
463-
int idx = 0;
464-
for (auto arg_ : sg->GetOutputs()) {
465-
if (arg == arg_) {
466-
return idx;
376+
void CreateFunction(AnnotatedRegion region, const CallNode* call) {
377+
// Create fields which is a unique list of outputs. Also populate
378+
// region_return_indices_ map which maps parent of compiler_end node to
379+
// corresponding index in fields.
380+
Array<Expr> fields;
381+
int i = 0;
382+
for (auto ret : region->GetOutputs()) {
383+
auto ret_node = Downcast<Call>(ret)->args[0];
384+
// Don't duplicate outputs.
385+
if (!region_return_indices_.count(region) ||
386+
!region_return_indices_[region].count(ret_node)) {
387+
auto ret_expr = VisitExpr(ret_node);
388+
fields.push_back(ret_expr);
389+
region_return_indices_[region][ret_node] = i;
390+
i++;
467391
}
468-
idx++;
469392
}
470-
return -1;
393+
394+
Array<Var> params;
395+
Array<Expr> param_expr;
396+
std::unordered_map<std::string, runtime::NDArray> params_bind;
397+
398+
for (auto pair : region_args[region]) {
399+
params.push_back(pair.first);
400+
if (const auto* cn = pair.second.as<ConstantNode>()) {
401+
params_bind[pair.first->name_hint()] = cn->data;
402+
} else {
403+
param_expr.push_back(pair.second);
404+
}
405+
}
406+
407+
Function global_region_func;
408+
if (fields.size() == 1) {
409+
// If there are only a single output; no need to add a tuple
410+
global_region_func =
411+
Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
412+
} else {
413+
auto tuple = Tuple(fields);
414+
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
415+
}
416+
417+
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
418+
std::string name = target + "_" + std::to_string(region->GetID());
419+
420+
global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
421+
runtime::String(name));
422+
global_region_func =
423+
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
424+
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
425+
tvm::runtime::String(target));
426+
global_region_func =
427+
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
428+
429+
// Constant propagation
430+
if (!params_bind.empty()) {
431+
global_region_func = backend::BindParamsByName(global_region_func, params_bind);
432+
}
433+
434+
std::string fname = name;
435+
CHECK(!module_->ContainGlobalVar(fname))
436+
<< "Global function " << fname << " already exists";
437+
// Create a global function and add it to the IRModule for the region.
438+
// This way we lift the functions that should be handled by external
439+
// codegen to the module scope and rely on the pass manager to prevent
440+
// relay function level passes (i.e. simplify inference and fusion)
441+
// optimizing it.
442+
GlobalVar glob_func(fname);
443+
module_->Add(glob_func, global_region_func);
444+
445+
// The return type of callnode is the same as the type of the
446+
// compiler_end node.
447+
auto ret = Call(glob_func, param_expr);
448+
region_function_calls[region] = ret;
449+
}
450+
451+
/*!
452+
* \brief Get the return(output) of the function for compiler end node "end_arg".
453+
* This will return either a Call (for a function with a single output) or a
454+
* TupleGetItem (for a function with multiple outputs).
455+
*/
456+
Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
457+
Expr arg = Downcast<Call>(end_arg)->args[0];
458+
// Function has one output.
459+
if (region_return_indices_[region].size() == 1) {
460+
return region_function_calls[region];
461+
}
462+
// Function has multiple outputs.
463+
// Use already made TupleGetItem.
464+
if (region_return_tuplegetitem_.count(region) &&
465+
region_return_tuplegetitem_[region].count(arg)) {
466+
return region_return_tuplegetitem_[region][arg];
467+
}
468+
// Create new TupleGetItem.
469+
CHECK(region_return_indices_.count(region) &&
470+
region_return_indices_[region].count(arg));
471+
int index = region_return_indices_[region][arg];
472+
473+
auto func_call = region_function_calls[region];
474+
auto tuple_get_item_ = TupleGetItem(func_call, index);
475+
tuple_get_item_->checked_type_ = arg->checked_type_;
476+
region_return_tuplegetitem_[region][arg] = tuple_get_item_;
477+
return std::move(tuple_get_item_);
471478
}
472479

473480
/*!
@@ -485,6 +492,23 @@ class Partitioner : public ExprMutator {
485492
std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
486493
region_args;
487494

495+
/*!
496+
* \brief This map maintains the index of an output in the subgraph function
497+
* for a given region. If there are multiple entries for a region, then the
498+
* function has a tuple of multiple outputs for its return.
499+
*/
500+
using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
501+
std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
502+
region_return_indices_;
503+
504+
/*!
505+
* \brief This map holds already created TupleGetItem nodes for accessing
506+
* outputs of a function.
507+
*/
508+
using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
509+
std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
510+
region_return_tuplegetitem_;
511+
488512
/*!
489513
* \brief Each region set is associated with a function in the module.
490514
* This map maintains the mapping between regionsets and the function it

0 commit comments

Comments
 (0)