@@ -7,6 +7,7 @@
 import numpy as np
 import torch
 from torch.fx.node import Argument, Node, Target
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -48,7 +49,7 @@ def get_ir(target: Target) -> SourceIR:
     return SourceIR.UNKNOWN


-def one_user_validator(node: Node) -> bool:
+def one_user_validator(node: Node, settings: CompilationSettings = None) -> bool:
     # Validate only one user, which is a getitem node that accesses the first element in the list
     return (
         len(node.users) == 1
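Throughout this file, every capability validator gains the same optional `settings: CompilationSettings = None` parameter. For orientation, a minimal sketch of the resulting contract, assuming a hypothetical `run_validator` helper (the real registry call site is not shown in this diff):

    from typing import Callable, Optional

    from torch.fx.node import Node
    from torch_tensorrt.dynamo._settings import CompilationSettings

    # New contract: validators accept (node, settings); settings defaults to
    # None so older single-argument call sites keep working.
    Validator = Callable[[Node, Optional[CompilationSettings]], bool]

    def run_validator(
        validator: Validator,
        node: Node,
        settings: Optional[CompilationSettings] = None,
    ) -> bool:
        # Forward the active compilation settings so a validator can gate
        # conversion on options such as make_refittable.
        return validator(node, settings)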
@@ -270,7 +271,11 @@ def aten_ops_embedding(
     )


-def embedding_bag_validator(node: Node) -> bool:
+def embedding_bag_validator(node: Node, settings: CompilationSettings = None) -> bool:
+    # Embedding bag op is not refittable
+    if settings and settings.make_refittable:
+        return False
+
     if not one_user_validator(node):
         return False
     meta = node.args[1].meta
@@ -416,7 +421,7 @@ def aten_ops_symsize_int(
     return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])


-def index_dtype_validator(node: Node) -> bool:
+def index_dtype_validator(node: Node, settings: CompilationSettings = None) -> bool:
     index = node.args[1]
     for ind in index:
         if ind is not None:
@@ -837,7 +842,7 @@ def aten_ops_select(
     )


-def index_put_validator(node: Node) -> bool:
+def index_put_validator(node: Node, settings: CompilationSettings = None) -> bool:
     if args_bounds_check(node.args, 3, False):  # Check if accumulate is valid
         _LOGGER.debug("We do not support accumulate=True for aten.index_put operation")
         accumulate_valid = False
@@ -924,7 +929,18 @@ def aten_ops_slice(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default, supports_dynamic_shapes=True)
+def refit_validator(node: Node, settings: CompilationSettings = None) -> bool:
+    # cumsum op is not refittable
+    if settings and settings.make_refittable:
+        return False
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.cumsum.default,
+    capability_validator=refit_validator,
+    supports_dynamic_shapes=True,
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
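`capability_validator` is evaluated during partitioning, so when `make_refittable` is enabled, `aten.cumsum.default` is rejected here and falls back to eager PyTorch rather than failing later at refit time. The same gate composes with any other check; a sketch, where `refit_gated` is a hypothetical combinator and not part of this PR:

    from torch.fx.node import Node
    from torch_tensorrt.dynamo._settings import CompilationSettings

    def refit_gated(inner):
        # Hypothetical combinator: reject the node whenever the engine must
        # stay refittable, otherwise defer to the wrapped validator.
        def validator(node: Node, settings: CompilationSettings = None) -> bool:
            if settings and settings.make_refittable:
                return False
            return inner(node, settings)
        return validator

    # e.g. capability_validator=refit_gated(one_user_validator)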
@@ -970,7 +986,7 @@ def aten_ops_tile(
     )


-def zero_output_validator(node: Node) -> bool:
+def zero_output_validator(node: Node, settings: CompilationSettings = None) -> bool:
     if 0 in node.args[1]:
         _LOGGER.debug(
             f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation."
@@ -1027,7 +1043,9 @@ def aten_ops_permute(
     )


-def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]:
+def to_copy_dtype_validator(
+    placeholder_only: bool, settings: CompilationSettings = None
+) -> Callable[[Node, CompilationSettings], bool]:
     """Return validator for to_copy node with placeholder restrictions"""

     def validate_dtype(to_copy_node: Node) -> bool:
@@ -1059,7 +1077,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
             )
             return False

-    def validator(to_copy_node: Node) -> bool:
+    def validator(to_copy_node: Node, settings: CompilationSettings = None) -> bool:
         """Returns true if the to_copy node can be converted to TRT
         and the placeholder restriction is satisfied
         """
@@ -1074,7 +1092,9 @@ def validator(to_copy_node: Node) -> bool:

 @dynamo_tensorrt_converter(
     torch.ops.aten.clone.default,
-    capability_validator=lambda node: not is_only_operator_on_placeholder(node),
+    capability_validator=lambda node, settings: not is_only_operator_on_placeholder(
+        node, settings
+    ),
     supports_dynamic_shapes=True,
 )
 @dynamo_tensorrt_converter(
@@ -2128,7 +2148,7 @@ def aten_ops_logical_xor(
     )


-def bitwise_type_validator(node: Node) -> bool:
+def bitwise_type_validator(node: Node, settings: CompilationSettings = None) -> bool:
     supported_type = [torch.bool, bool]

     tensor_targets = [
@@ -2271,7 +2291,9 @@ def aten_ops_bitwise_xor(
     )


-def bitwise_not_type_validator(node: Node) -> bool:
+def bitwise_not_type_validator(
+    node: Node, settings: CompilationSettings = None
+) -> bool:
     val = node.args[0]
     val_meta = val.meta.get("tensor_meta")

@@ -2453,7 +2475,7 @@ def aten_ops_le(
     )


-def conv_param_validator(conv_node: Node) -> bool:
+def conv_param_validator(conv_node: Node, settings: CompilationSettings = None) -> bool:
     return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])


@@ -2549,7 +2571,9 @@ def aten_ops_cdist_forward(
     )


-def avg_pool_param_validator(pool_node: Node) -> bool:
+def avg_pool_param_validator(
+    pool_node: Node, settings: CompilationSettings = None
+) -> bool:
     ceil_mode = args_bounds_check(pool_node.args, 4, False)
     divisor_override = args_bounds_check(pool_node.args, 6)

@@ -2665,12 +2689,12 @@ def aten_ops_adaptive_avg_poolNd(
     )


-def topk_validator(node: Node) -> bool:
+def topk_validator(node: Node, settings: CompilationSettings = None) -> bool:
     k = node.args[1]
     return topk_sort_validator(k)


-def sort_validator(node: Node) -> bool:
+def sort_validator(node: Node, settings: CompilationSettings = None) -> bool:
     meta_data = node.args[0].meta.get("tensor_meta")
     if meta_data is None:
         return False
@@ -2692,7 +2716,9 @@ def topk_sort_validator(k: int) -> bool:
     return True


-def max_pool_param_validator(pool_node: Node) -> bool:
+def max_pool_param_validator(
+    pool_node: Node, settings: CompilationSettings = None
+) -> bool:
     dilation = args_bounds_check(pool_node.args, 4, 1)
     ceil_mode = args_bounds_check(pool_node.args, 5, False)

@@ -2746,7 +2772,7 @@ def aten_ops_max_pool(
     )


-def attention_validator(node: Node) -> bool:
+def attention_validator(node: Node, settings: CompilationSettings = None) -> bool:
     # Currently, `attn_mask` is not supported
     return args_bounds_check(node.args, 3) is None

@@ -3637,7 +3663,7 @@ def aten_ops_flip(
     )


-def zero_diag_size_validator(node: Node) -> bool:
+def zero_diag_size_validator(node: Node, settings: CompilationSettings = None) -> bool:
     meta = node.args[0].meta.get("tensor_meta")
     if meta:
         input_shape = meta.shape
@@ -3765,7 +3791,9 @@ def aten_ops_index_select(
     )


-def dropout_inference_validator(node: Node) -> bool:
+def dropout_inference_validator(
+    node: Node, settings: CompilationSettings = None
+) -> bool:
     train_mode = args_bounds_check(node.args, 2, None)
     if train_mode is False:
         return True
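End to end, the new signatures let a caller probe any validator with or without settings. A usage sketch; `cumsum_node` is a stand-in for a real FX node targeting `torch.ops.aten.cumsum.default`, and constructing `CompilationSettings(make_refittable=True)` assumes the field this diff reads from exists on the settings dataclass:

    from torch_tensorrt.dynamo._settings import CompilationSettings

    # `cumsum_node` is a placeholder for a real torch.fx Node; obtaining one
    # from a traced graph is elided here.
    assert refit_validator(cumsum_node)  # no settings given: nothing to gate on
    assert refit_validator(cumsum_node, CompilationSettings())  # refit off by default (assumed)
    assert not refit_validator(cumsum_node, CompilationSettings(make_refittable=True))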