from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem
-from tvm.relay.expr_functor import ExprMutator
+from tvm.relay.expr_functor import ExprMutator, ExprVisitor

logger = logging.getLogger("TensorRT")
@@ -173,7 +173,7 @@ def check_dynamism(args, op_name):
    """
    for arg in args:
        if isinstance(arg, (Call, Var, Constant, TupleGetItem)):
-            for dim_shape in arg.checked_type.shape:
+            for dim_shape in arg.checked_type.shape[1:]:
                if isinstance(dim_shape, tvm.tir.expr.Any):
                    return True
        elif isinstance(arg, Tuple):
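# Standalone sketch (not part of the patch): how the relaxed dynamism check above behaves
# for hypothetical inputs. Only a dynamic *batch* dimension is now tolerated.
from tvm import relay

# Dynamic batch only: shape[1:] == (3, 224, 224) contains no Any, so ops consuming
# `x` are no longer flagged as dynamic and can stay in the TensorRT region.
x = relay.var("x", shape=(relay.Any(), 3, 224, 224), dtype="float32")

# Dynamic non-batch dimension: shape[1:] still contains Any, so check_dynamism
# returns True and the op is left to the Relay VM.
y = relay.var("y", shape=(1, relay.Any(), 224, 224), dtype="float32")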
@@ -198,6 +198,21 @@ def _func_wrapper(expr):
        if any([x.checked_type.dtype != "float32" for x in args]):
            logger.info("Only float32 inputs are supported for TensorRT.")
            return False
+        if op_name == "multiply":
+            shapes = [
+                [
+                    int(x) if not isinstance(x, tvm.tir.expr.Any) else -1
+                    for x in arg.checked_type.shape
+                ]
+                for arg in args
+            ]
+            # Batched multiply operations don't work in implicit batch mode. The following shapes
+            # are excluded because they occur in the PyTorch MaskRCNN model. The long-term solution
+            # is to switch to explicit batch mode once its performance regressions are solved.
+            if all(
+                [list(map(int, shape)) in [[300, 64, 7, 7], [300, 1, 1, 1]] for shape in shapes]
+            ):
+                return False
        return checker(attrs, args, op_name)

    return _func_wrapper
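# Standalone sketch (not part of the patch): the normalization above maps dynamic dims
# to -1, and the excluded shape pair is the batched multiply from the PyTorch MaskRCNN
# model. `normalize_shape` is a hypothetical helper, not code from this file.
import tvm

def normalize_shape(shape):
    # Concrete dims become ints, dynamic (Any) dims become -1, mirroring the check above.
    return [int(x) if not isinstance(x, tvm.tir.expr.Any) else -1 for x in shape]

print(normalize_shape([300, 64, tvm.tir.Any(), 7]))  # [300, 64, -1, 7]

shapes = [[300, 64, 7, 7], [300, 1, 1, 1]]
print(all(s in [[300, 64, 7, 7], [300, 1, 1, 1]] for s in shapes))  # True -> not offloaded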
@@ -292,19 +307,26 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable
    """Check if add is supported by TensorRT."""

    args = expr.args
+
+    shapes = [
+        [int(x) if not isinstance(x, tvm.tir.expr.Any) else -1 for x in arg.checked_type.shape]
+        for arg in args
+    ]
+
    # RelayVM + TRT doesn't support scalar addition yet.
-    for arg in args:
-        if not arg.checked_type.shape:
+    for shape in shapes:
+        if len(shape) < 1:
            return False
+
    if any([x.checked_type.dtype != "float32" for x in args]):
        logger.info("Only float32 inputs are supported for TensorRT.")
        return False
    if (
        not get_tensorrt_use_implicit_batch_mode()
        and (isinstance(args[0], Constant) or isinstance(args[1], Constant))
-        and args[0].checked_type.shape[0] == args[1].checked_type.shape[0]
-        and args[0].checked_type.shape[0] != 1
-        and (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3)
+        and shapes[0][0] == shapes[1][0]
+        and shapes[0][0] != 1
+        and (len(shapes[0]) > 3 or len(shapes[1]) > 3)
    ):
        logger.info("add: bug in TRT with adding batched constants.")
        return False
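# Standalone sketch (not part of the patch): the guard above fires when implicit batch
# mode is disabled, one operand is a Constant, and both operands share a leading
# dimension greater than 1 with rank above 3, the combination that hits the TRT
# batched-constant bug. Shapes below are hypothetical, already normalized (Any -> -1).
shapes = [[8, 3, 32, 32], [8, 3, 32, 32]]  # one operand is assumed to be a Constant
hits_trt_bug = (
    shapes[0][0] == shapes[1][0]
    and shapes[0][0] != 1
    and (len(shapes[0]) > 3 or len(shapes[1]) > 3)
)
print(hits_trt_bug)  # True -> "add: bug in TRT with adding batched constants."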
@@ -592,11 +614,35 @@ def reshape_annotate_fn(expr): # pylint: disable=unused-variable
        logger.info("reshape: new shape dims must be explicit.")
        return False
    if get_tensorrt_use_implicit_batch_mode():
-        shape = list(map(int, args[0].checked_type.shape))
-        new_shape = list(map(int, attrs.newshape))
+        shape = args[0].checked_type.shape
+        new_shape = attrs.newshape
        if len(new_shape) == 0 or len(shape) == 0:
            logger.info("reshape: Can't reshape to or from scalar.")
            return False
+
+        dynamic_reshape = any([isinstance(x, tvm.tir.expr.Any) for x in shape])
+
+        if dynamic_reshape:
+            # Make sure that the batch dim is unmodified.
+            if int(new_shape[0]) < 0:
+                for shape_val, new_shape_val in zip(shape[1:], new_shape[1:]):
+                    if not (
+                        isinstance(shape_val, (int, tvm.tir.expr.IntImm))
+                        and isinstance(new_shape_val, (int, tvm.tir.expr.IntImm))
+                        and int(shape_val) == int(new_shape_val)
+                    ):
+                        return False
+            elif int(new_shape[0]) > 0:
+                if not (
+                    isinstance(shape[0], (int, tvm.tir.expr.IntImm))
+                    and isinstance(new_shape[0], (int, tvm.tir.expr.IntImm))
+                    and int(shape[0]) == int(new_shape[0])
+                ):
+                    return False
+            return True
+        shape = list(map(int, shape))
+        new_shape = list(map(int, new_shape))
+
        # TRT cannot modify batch dimension.
        original_volume = np.prod(shape)
        # First, resolve 0.
@@ -607,6 +653,7 @@ def reshape_annotate_fn(expr): # pylint: disable=unused-variable
        for i, value in enumerate(new_shape):
            if value == -1:
                new_shape[i] = original_volume // np.prod([x for x in new_shape if x != -1])
+        # Remove batch dimension and see if volumes match
        if shape[0] != new_shape[0]:
            logger.info("reshape: can't modify batch dimension.")
            return False
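# Standalone sketch (not part of the patch): what the dynamic-reshape branch above accepts
# and rejects. Shapes are hypothetical; relay.Any() marks the dynamic batch dimension.
from tvm import relay

x = relay.var("x", shape=(relay.Any(), 32, 7, 7), dtype="float32")

# Accepted: newshape[0] == -1 and every non-batch dim matches the input, so the batch
# dimension is provably left untouched.
ok = relay.reshape(x, newshape=(-1, 32, 7, 7))

# Rejected: folding the spatial dims into one axis means the non-batch dims no longer
# match one-to-one, so this reshape stays off TensorRT.
bad = relay.reshape(x, newshape=(-1, 32 * 7 * 7))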
@@ -795,31 +842,73 @@ def conv3d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
    return True


+class IsComputeIntensiveGraph(ExprVisitor):
+    """
+    Visits the graph recursively and checks if it contains compute-heavy ops such as
+    convolutions and their transposes, dense, and batch matmul.
+    """
+
+    def __init__(self):
+        ExprVisitor.__init__(self)
+        self.is_compute_intensive = False
+
+    def visit_call(self, call):
+        compute_intensive_ops = set(
+            [
+                "nn.conv2d",
+                "nn.conv2d_transpose",
+                "nn.conv3d",
+                "nn.conv3d_transpose",
+                "nn.dense",
+                "nn.batch_matmul",
+            ]
+        )
+        if isinstance(call.op, tvm.tir.op.Op):
+            if str(call.op) in compute_intensive_ops:
+                self.is_compute_intensive = True
+
+        return super().visit_call(call)
+
+    def is_graph_compute_intensive(self, subgraph) -> bool:
+        """
+        Recursively visits the graph and checks whether it is compute intensive.
+        """
+        self.visit(subgraph)
+        return self.is_compute_intensive
+
+
def is_valid_subgraph(params, body):
    """Final check on whether the subgraph is valid and should be offloaded to TensorRT."""
    # Remove invalid subgraphs for implicit batch mode.
    if get_tensorrt_use_implicit_batch_mode():
        input_batch_sizes = []
        for var in params:
            # In implicit batch mode, all inputs must have same batch size
+            # TODO(codeislife99): Fix handling of inputs with different dynamic batch sizes.
+
            if isinstance(var.checked_type, relay.TupleType):
                for tupe_type in var.checked_type.fields:
                    # Scalar inputs not allowed
                    if len(tupe_type.shape) == 0:
                        logger.info("tensorrt: scalar inputs not supported")
                        return False
-                    input_batch_sizes.append(int(tupe_type.shape[0]))
+
+                    if not isinstance(tupe_type.shape[0], tvm.tir.expr.Any):
+                        input_batch_sizes.append(int(tupe_type.shape[0]))
            else:
                # Scalar inputs not allowed
                if len(var.checked_type.shape) == 0:
                    logger.info("tensorrt: scalar inputs not supported")
                    return False
-                input_batch_sizes.append(int(var.checked_type.shape[0]))
+                if not isinstance(var.checked_type.shape[0], tvm.tir.expr.Any):
+                    input_batch_sizes.append(int(var.checked_type.shape[0]))
        if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) != 1:
            logger.info("tensorrt: inputs have different batch sizes")
            return False
-    # Remove subgraphs with no multiply-accumulates
-    if get_tensorrt_remove_no_mac_subgraphs() and relay.analysis.get_total_mac_number(body) == 0:
+    if (
+        get_tensorrt_remove_no_mac_subgraphs()
+        and not IsComputeIntensiveGraph().is_graph_compute_intensive(body)
+    ):
        return False
    return True
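# Standalone sketch (not part of the patch): using the new visitor to decide whether a toy
# subgraph is worth offloading. Names and shapes are illustrative.
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")
func = relay.Function([data, weight], relay.nn.conv2d(data, weight))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))

# The body contains nn.conv2d, one of the compute-intensive ops, so the subgraph is kept
# even when remove_no_mac_subgraphs is enabled.
keep = IsComputeIntensiveGraph().is_graph_compute_intensive(mod["main"].body)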
@@ -880,6 +969,8 @@ class RemoveDropout(ExprMutator):

    def visit_tuple_getitem(self, op):
        visit = super().visit_tuple_getitem(op)
+        if visit.index != 0:
+            return visit
        if (
            isinstance(visit.tuple_value, Call)
            and visit.tuple_value.op.name == "nn.dropout"
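# Standalone sketch (not part of the patch): why the index check above matters. In Relay,
# the underlying nn.dropout op produces a tuple and the frontend returns element 0 of it,
# so RemoveDropout should rewrite only TupleGetItem(..., 0) and leave other indices alone.
from tvm import relay

x = relay.var("x", shape=(1, 16), dtype="float32")
y = relay.nn.dropout(x, rate=0.5)  # a TupleGetItem with index 0 over the dropout call
func = relay.Function([x], y)

# With the early return above, only the index-0 access is rewritten to bypass dropout;
# any other index is returned unchanged.
new_body = RemoveDropout().visit(func.body)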