Skip to content

Commit af66383

Browse files
masahiMasahiro Masuda
authored andcommitted
fix for reference handling and isolated cases
1 parent d17aabb commit af66383

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

src/relay/pass/fuse_ops.cc

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -261,13 +261,22 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
261261
}
262262

263263
void VisitExpr_(const TupleGetItemNode* op) final {
264-
CHECK(graph_.node_map.count(op));
265-
Node* node = graph_.node_map.at(op);
266-
node->pattern = kInjective;
267-
if (op->tuple->checked_type().as<TupleTypeNode>()) {
268-
this->Update(op->tuple, node, kInjective);
269-
} else {
264+
auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
265+
CHECK(tuple_type);
266+
bool has_reference = false;
267+
for (auto ty : tuple_type->fields) {
268+
if (auto ref_ty = ty.as<RefTypeNode>()) {
269+
has_reference = true;
270+
break;
271+
}
272+
}
273+
if (has_reference) {
270274
this->Update(op->tuple, nullptr, kOpaque);
275+
} else {
276+
CHECK(graph_.node_map.count(op));
277+
Node* node = graph_.node_map.at(op);
278+
node->pattern = kInjective;
279+
this->Update(op->tuple, node, kInjective);
271280
}
272281
ExprVisitor::VisitExpr_(op);
273282
this->AddNode(op);
@@ -815,10 +824,16 @@ class FuseMutator : private ExprMutator {
815824
}
816825

817826
Expr VisitExpr_(const TupleGetItemNode* tuple_get) {
818-
auto new_node = TupleGetItemNode::make(this->Mutate(tuple_get->tuple), tuple_get->index);
819827
auto* ret_group = gmap_.at(tuple_get)->FindRoot();
828+
auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
829+
auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index);
820830
if (ret_group == gmap_.at(tuple_get)) {
821-
// unlike the tuple case above, this node should never be isolated
831+
if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
832+
// Isolated. This case occurs when tuple is created by an Opaque op
833+
// e.g. multibox_transform_loc
834+
return ExprMutator::VisitExpr_(tuple_get);
835+
}
836+
// A new function whose output is a tuple field access
822837
return MakeNewFunction(ret_group, tuple_get->checked_type(), new_node);
823838
}
824839
// This is an intermediate node in the group

0 commit comments

Comments
 (0)