Skip to content

Commit 45f1f8b

Browse files
rohanmukhylc
authored andcommitted
[TensorRT, BYOC] Handling a corner case in TRT RemoveDropout pass (apache#8506)
* [TensorRT, BYOC] Handling a corner case in TRT RemoveDropout pass * changing visit logic
1 parent 95a8c41 commit 45f1f8b

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

python/tvm/relay/op/contrib/tensorrt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tvm.relay import transform
2424
from tvm.relay.build_module import bind_params_by_name
2525
from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem
26+
from tvm.ir import Op
2627
from tvm.relay.expr_functor import ExprMutator, ExprVisitor
2728

2829
logger = logging.getLogger("TensorRT")
@@ -1044,6 +1045,7 @@ def visit_tuple_getitem(self, op):
10441045
return visit
10451046
if (
10461047
isinstance(visit.tuple_value, Call)
1048+
and isinstance(visit.tuple_value.op, Op)
10471049
and visit.tuple_value.op.name == "nn.dropout"
10481050
and visit.index == 0
10491051
):

0 commit comments

Comments
 (0)