Skip to content

Commit 6b7f17c

Browse files
committed
changing visit logic
1 parent 6aba042 commit 6b7f17c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

python/tvm/relay/op/contrib/tensorrt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from tvm import relay
2323
from tvm.relay import transform
2424
from tvm.relay.build_module import bind_params_by_name
25-
from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem, Let
25+
from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem
26+
from tvm.ir import Op
2627
from tvm.relay.expr_functor import ExprMutator, ExprVisitor
2728

2829
logger = logging.getLogger("TensorRT")
@@ -1033,10 +1034,9 @@ def visit_tuple_getitem(self, op):
10331034
visit = super().visit_tuple_getitem(op)
10341035
if visit.index != 0:
10351036
return visit
1036-
if isinstance(visit.tuple_value, Call) and isinstance(visit.tuple_value.op, Let):
1037-
return visit
10381037
if (
10391038
isinstance(visit.tuple_value, Call)
1039+
and isinstance(visit.tuple_value.op, Op)
10401040
and visit.tuple_value.op.name == "nn.dropout"
10411041
and visit.index == 0
10421042
):

0 commit comments

Comments
 (0)