@@ -205,99 +205,13 @@ class Partitioner : public ExprMutator {
205205 // region_function_calls is a map that maintains
206206 // (each annotated region) --> call to the function created for it
207207
208- if (region_function_calls.find (region) != region_function_calls.end ()) {
209- // This section is executed only if there are multiple outputs in the
210- // region Thus, the function is always created and at the end there
211- // would be a tuple node Therefore, we insert a tuple get item node.
212-
213- // Use the already created tuple node
214- auto sg_call = region_function_calls[region];
215- int index = GetRetIdx (region, GetRef<Call>(call));
216- CHECK_NE (index, -1 );
217-
218- auto tuple_get_item_ = TupleGetItem (sg_call, index);
219- tuple_get_item_->checked_type_ = GetRef<Call>(call)->args [0 ]->checked_type_ ;
220- return std::move (tuple_get_item_);
221- } else {
222- // First time this region is encountered in the traversal
223- // Creating the function
224-
225- Array<Expr> fields;
226-
227- for (auto ret : region->GetOutputs ()) {
228- auto ret_expr = VisitExpr (Downcast<Call>(ret)->args [0 ]);
229- fields.push_back (ret_expr);
230- }
231- int index = GetRetIdx (region, GetRef<Call>(call));
232- CHECK_NE (index, -1 );
233-
234- Array<Var> params;
235- Array<Expr> param_expr;
236- std::unordered_map<std::string, runtime::NDArray> params_bind;
237-
238- for (auto pair : region_args[region]) {
239- params.push_back (pair.first );
240- if (const auto * cn = pair.second .as <ConstantNode>()) {
241- params_bind[pair.first ->name_hint ()] = cn->data ;
242- } else {
243- param_expr.push_back (pair.second );
244- }
245- }
246-
247- Function global_region_func;
248- if (region->GetOutputs ().size () == 1 ) {
249- // If there are only a single output; no need to add a tuple
250- global_region_func =
251- Function (params, fields[0 ], call->args [0 ]->checked_type_ , {}, DictAttrs ());
252- } else {
253- auto tuple = Tuple (fields);
254- global_region_func = Function (params, tuple, tuple->checked_type_ , {}, DictAttrs ());
255- }
256-
257- std::string target = call->attrs .as <CompilerAttrs>()->compiler ;
258- std::string name = target + " _" + std::to_string (region->GetID ());
259-
260- global_region_func = WithAttr (std::move (global_region_func), tvm::attr::kGlobalSymbol ,
261- runtime::String (name));
262- global_region_func =
263- WithAttr (std::move (global_region_func), attr::kPrimitive , tvm::Integer (1 ));
264- global_region_func = WithAttr (std::move (global_region_func), attr::kCompiler ,
265- tvm::runtime::String (target));
266- global_region_func =
267- WithAttr (std::move (global_region_func), attr::kInline , tvm::Integer (1 ));
268-
269- // Constant propagation
270- if (!params_bind.empty ()) {
271- global_region_func = backend::BindParamsByName (global_region_func, params_bind);
272- }
273-
274- std::string fname = name;
275- CHECK (!module_->ContainGlobalVar (fname))
276- << " Global function " << fname << " already exists" ;
277- // Create a global function and add it to the IRModule for the region.
278- // This way we lift the functions that should be handled by external
279- // codegen to the module scope and rely on the pass manager to prevent
280- // relay function level passes (i.e. simplify inference and fusion)
281- // optimizing it.
282- GlobalVar glob_func (fname);
283- module_->Add (glob_func, global_region_func);
284-
285- // The return type of callnode is the same as the type of the
286- // compiler_end node.
287- auto ret = Call (glob_func, param_expr);
288- region_function_calls[region] = ret;
289-
290- if (region->GetOutputs ().size () == 1 ) {
291- // If there is only a single output; no need to add a tuplegetitem
292- // node
293- return std::move (ret);
294- } else {
295- // Add a tuplegetitem node to select this output out of many
296- auto tuple_get_item_ = TupleGetItem (ret, index);
297- tuple_get_item_->checked_type_ = GetRef<Call>(call)->args [0 ]->checked_type_ ;
298- return std::move (tuple_get_item_);
299- }
208+ if (region_function_calls.find (region) == region_function_calls.end ()) {
209+ // First time this region is encountered in the traversal.
210+ // Creating the function.
211+ CreateFunction (region, call);
300212 }
213+ // Retrieve this particular output of function.
214+ return GetFunctionOutput (region, GetRef<Call>(call));
301215 }
302216 }
303217
@@ -456,18 +370,111 @@ class Partitioner : public ExprMutator {
456370 }
457371
458372 /* !
459- * \brief Get the index of the return(output);
460- * this is to be used as tuplegetitem idx
373+ * \brief This function is called the first time we encounter a compiler_end
374+ * node of a region, to create the function for that region's subgraph.
461375 */
462- int GetRetIdx (AnnotatedRegion sg, const Expr& arg) {
463- int idx = 0 ;
464- for (auto arg_ : sg->GetOutputs ()) {
465- if (arg == arg_) {
466- return idx;
376+ void CreateFunction (AnnotatedRegion region, const CallNode* call) {
     // Builds the function for this region, adds it to module_ under a new
     // GlobalVar, and stores the call to it in region_function_calls. Nothing
     // is returned; the per-output index of each region output is recorded in
     // region_return_indices_ as a side effect.
     //
     // region: the annotated region to turn into a function.
     // call:   the compiler_end call being visited; its attrs supply the
     //         compiler name and its first argument supplies the return type
     //         in the single-output case.
377+ // Create fields, the list of unique outputs of the region. Also populate
378+ // the region_return_indices_ map, which maps the argument of each
379+ // compiler_end node to its corresponding index in fields.
380+ Array<Expr> fields;
381+ int i = 0 ;
382+ for (auto ret : region->GetOutputs ()) {
383+ auto ret_node = Downcast<Call>(ret)->args [0 ];
384+ // Don't duplicate outputs: skip an argument already indexed for this region.
385+ if (!region_return_indices_.count (region) ||
386+ !region_return_indices_[region].count (ret_node)) {
     // The argument is visited (mutated) before being stored as a field, but
     // the index map is keyed by the unvisited argument.
387+ auto ret_expr = VisitExpr (ret_node);
388+ fields.push_back (ret_expr);
389+ region_return_indices_[region][ret_node] = i;
390+ i++;
467391 }
468- idx++;
469392 }
470- return -1 ;
393+
     // Split region_args[region] into function parameters and call arguments.
     // Every variable becomes a parameter. Constant values are collected in
     // params_bind keyed by the parameter name hint instead of being passed
     // at the call site; everything else goes into param_expr.
394+ Array<Var> params;
395+ Array<Expr> param_expr;
396+ std::unordered_map<std::string, runtime::NDArray> params_bind;
397+
398+ for (auto pair : region_args[region]) {
399+ params.push_back (pair.first );
400+ if (const auto * cn = pair.second .as <ConstantNode>()) {
401+ params_bind[pair.first ->name_hint ()] = cn->data ;
402+ } else {
403+ param_expr.push_back (pair.second );
404+ }
405+ }
406+
407+ Function global_region_func;
408+ if (fields.size () == 1 ) {
409+ // If there is only a single output, no need to add a tuple. The return
     // type is taken from the first argument of the given compiler_end call.
410+ global_region_func =
411+ Function (params, fields[0 ], call->args [0 ]->checked_type_ , {}, DictAttrs ());
412+ } else {
     // Multiple outputs: the function body is a tuple of all unique outputs.
     // NOTE(review): tuple is freshly constructed here; whether its
     // checked_type_ is already populated is not visible from this block.
413+ auto tuple = Tuple (fields);
414+ global_region_func = Function (params, tuple, tuple->checked_type_ , {}, DictAttrs ());
415+ }
416+
     // The function name is built from the compiler name and the region ID.
417+ std::string target = call->attrs .as <CompilerAttrs>()->compiler ;
418+ std::string name = target + " _" + std::to_string (region->GetID ());
419+
     // Attach the global symbol, primitive, compiler and inline attributes.
420+ global_region_func = WithAttr (std::move (global_region_func), tvm::attr::kGlobalSymbol ,
421+ runtime::String (name));
422+ global_region_func =
423+ WithAttr (std::move (global_region_func), attr::kPrimitive , tvm::Integer (1 ));
424+ global_region_func = WithAttr (std::move (global_region_func), attr::kCompiler ,
425+ tvm::runtime::String (target));
426+ global_region_func =
427+ WithAttr (std::move (global_region_func), attr::kInline , tvm::Integer (1 ));
428+
429+ // Constant propagation: hand the collected constants to BindParamsByName.
430+ if (!params_bind.empty ()) {
431+ global_region_func = backend::BindParamsByName (global_region_func, params_bind);
432+ }
433+
     // A function with this name must not already be in the module.
434+ std::string fname = name;
435+ CHECK (!module_->ContainGlobalVar (fname))
436+ << " Global function " << fname << " already exists" ;
437+ // Create a global function and add it to the IRModule for the region.
438+ // This way we lift the functions that should be handled by external
439+ // codegen to the module scope and rely on the pass manager to prevent
440+ // relay function level passes (i.e. simplify inference and fusion)
441+ // optimizing it.
442+ GlobalVar glob_func (fname);
443+ module_->Add (glob_func, global_region_func);
444+
445+ // The return type of callnode is the same as the type of the
446+ // compiler_end node.
     // Only the non-constant arguments are passed. The call is cached per
     // region so it can be looked up again through region_function_calls.
447+ auto ret = Call (glob_func, param_expr);
448+ region_function_calls[region] = ret;
449+ }
450+
451+ /* !
452+ * \brief Get the return(output) of the function for compiler end node "end_arg".
453+ * This will return either a Call (for a function with a single output) or a
454+ * TupleGetItem (for a function with multiple outputs).
455+ */
456+ Expr GetFunctionOutput (AnnotatedRegion region, const Expr& end_arg) {
     // Returns the expression that stands for one output of the region function.
     // region:  region whose function call is stored in region_function_calls.
     // end_arg: the compiler_end call; its first argument identifies the output.
     // Reads region_function_calls[region] and region_return_indices_[region],
     // which are filled in by CreateFunction.
457+ Expr arg = Downcast<Call>(end_arg)->args [0 ];
458+ // Function has one output: the stored call itself is returned.
     // NOTE(review): operator[] here default-inserts an empty entry when the
     // region is missing, so the count (region) half of the CHECK further down
     // can no longer fail after this line.
459+ if (region_return_indices_[region].size () == 1 ) {
460+ return region_function_calls[region];
461+ }
462+ // Function has multiple outputs.
463+ // Use the already made TupleGetItem for this output, if there is one.
464+ if (region_return_tuplegetitem_.count (region) &&
465+ region_return_tuplegetitem_[region].count (arg)) {
466+ return region_return_tuplegetitem_[region][arg];
467+ }
468+ // Create new TupleGetItem. The output must have a recorded index.
469+ CHECK (region_return_indices_.count (region) &&
470+ region_return_indices_[region].count (arg));
471+ int index = region_return_indices_[region][arg];
472+
473+ auto func_call = region_function_calls[region];
474+ auto tuple_get_item_ = TupleGetItem (func_call, index);
     // The new node takes the checked type of the expression it replaces.
475+ tuple_get_item_->checked_type_ = arg->checked_type_ ;
     // Cache the node so a later request for the same output reuses it.
476+ region_return_tuplegetitem_[region][arg] = tuple_get_item_;
477+ return std::move (tuple_get_item_);
471478 }
472479
473480 /* !
@@ -485,6 +492,23 @@ class Partitioner : public ExprMutator {
485492 std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
486493 region_args;
487494
495+ /* !
496+ * \brief This map maintains the index of an output in the subgraph function
497+ * for a given region. If there are multiple entries for a region, then the
498+ * function has a tuple of multiple outputs for its return.
499+ */
     // Inner map: argument of a compiler_end node --> index of that output in
     // the list of unique outputs of the region function.
500+ using RegionRetIndexMap = std::unordered_map<Expr, int , ObjectHash, ObjectEqual>;
     // Outer map: region --> its output index map.
501+ std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
502+ region_return_indices_;
503+
504+ /* !
505+ * \brief This map holds already created TupleGetItem nodes for accessing
506+ * outputs of a function.
507+ */
     // Inner map: argument of a compiler_end node --> TupleGetItem node that
     // was created for that output.
508+ using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
     // Outer map: region --> its cached TupleGetItem nodes.
509+ std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
510+ region_return_tuplegetitem_;
511+
488512 /* !
489513 * \brief Each region set is associated with a function in the module.
490514 * This map maintains the mapping between regionsets and the function it
0 commit comments