@@ -46,6 +46,14 @@
 __all__ = ["from_pytorch"]


+def _is_version_greater_than(ver):
+    import torch
+    from packaging import version
+
+    # Torch version > 1.4 changed upsampling API
+    return version.parse(torch.__version__) > version.parse(ver)
+
+
 # List ADT utilities
 def _infer_type_with_prelude(val, prelude):
     body = _infer_type(val, prelude.mod)
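For reference, a minimal standalone sketch of the comparison this new helper centralizes. `packaging.version` compares release components numerically, which plain string comparison gets wrong for multi-digit versions and local build tags such as `+cu101`:

```python
from packaging import version

# Plain string comparison is lexicographic and mis-orders "1.10" vs "1.4":
assert not ("1.10.0" > "1.4.0")
# packaging.version compares release components numerically:
assert version.parse("1.10.0") > version.parse("1.4.0")
# Local build suffixes like "+cu101" parse cleanly instead of breaking the check:
assert version.parse("1.6.0+cu101") > version.parse("1.5.0")
```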
@@ -413,13 +421,18 @@ def _impl(inputs, input_types):
 def _split_with_sizes():
     def _impl(inputs, input_types):
         data = inputs[0]
+        sections = inputs[1]
         dim = int(inputs[2])

+        if len(sections) == 1:
+            # a special case used in torchvision detection models
+            return _expr.TupleWrapper(_expr.Tuple([data]), 1)
+
         split_index = 0
         indices = []
-        sections = inputs[1]
         for i in range(len(sections) - 1):
-            split_index += sections[i]
+            index, _ = try_infer_value(sections[i], lambda ret: int(ret))
+            split_index += index
             indices.append(split_index)

         return _op.split(data, indices, dim)
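The loop above converts per-chunk sizes into the cumulative boundary indices that `_op.split` expects; `try_infer_value` lets symbolic sizes fold to constants when possible. A small sketch of the index arithmetic on hypothetical static sizes:

```python
# torch.split_with_sizes takes chunk sizes; Relay's split takes the cut
# points, i.e. the running sum of every size except the last one.
sizes = [2, 3, 4]  # hypothetical chunk sizes along `dim` (axis length 9)
indices, split_index = [], 0
for size in sizes[:-1]:
    split_index += size
    indices.append(split_index)
assert indices == [2, 5]  # the axis is cut after positions 2 and 5
```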
@@ -522,6 +535,9 @@ def _impl(inputs, input_types):

 def _where():
     def _impl(inputs, input_types):
+        if len(inputs) == 1:
+            return _nonzero(False)([inputs[0], True], input_types)
+
         cond = inputs[0]
         x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3])
         return _op.where(cond, x, y)
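The new single-input branch mirrors PyTorch semantics: `torch.where(cond)` with one argument is shorthand for `torch.nonzero(cond, as_tuple=True)`, hence the delegation to `_nonzero`. A PyTorch-only illustration:

```python
import torch

cond = torch.tensor([[True, False], [False, True]])
# One-argument where == nonzero in tuple form, one index vector per dim:
for a, b in zip(torch.where(cond), torch.nonzero(cond, as_tuple=True)):
    assert torch.equal(a, b)
```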
@@ -1865,11 +1881,8 @@ def func(x):
             return _op.image.resize(x, out_size, "NCHW", method, coord_trans)

         if _is_quantized_tensor(data, prelude):
-            import torch
-            from packaging import version
-
             # Torch version > 1.4 changed upsampling API
-            if version.parse(torch.__version__) > version.parse("1.4.0"):
+            if _is_version_greater_than("1.4.0"):
                 num_inputs = 7
             else:
                 num_inputs = 5
@@ -2172,9 +2185,11 @@ def _impl(inputs, input_types):
         data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])

         # strided slice to get the dynamic result
-        return get_relay_op("strided_slice")(
+        ret = get_relay_op("strided_slice")(
             data_slice, begin=_expr.const([0]), end=size, slice_mode="size"
         )
+        # in torchvision, indices from nms are int64
+        return _op.cast(ret, "int64")

     return _impl

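The added cast keeps the converted graph consistent with torchvision, whose `nms` returns kept indices as int64. A quick dtype check (assumes torchvision is installed; the boxes and scores are made up):

```python
import torch
from torchvision.ops import nms

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]])
scores = torch.tensor([0.9, 0.8])
keep = nms(boxes, scores, iou_threshold=0.5)
assert keep.dtype == torch.int64  # indices come back as int64
```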
@@ -2266,9 +2281,8 @@ def _impl(inputs, input_types):
         ret = _op.transform.argwhere(data)

         if is_numpy_style or (len(inputs) > 1 and inputs[1]):
-            # TODO(kevinthesun): Support this by adding unbind op
-            # ret = _unbind()([ret, 0], None)
-            raise RuntimeError("as_tuple is not supported yet for nonzero.")
+            return _unbind()([ret, 1], None)
+
         return ret

     return _impl
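`argwhere` produces an `(n, ndim)` matrix with one coordinate row per nonzero element, so unbinding it along axis 1 yields one index vector per input dimension, which is exactly the `as_tuple=True` layout. The same reshaping sketched in NumPy:

```python
import numpy as np

x = np.array([[1, 0], [0, 2]])
coords = np.argwhere(x)    # shape (n, ndim): [[0, 0], [1, 1]]
per_dim = tuple(coords.T)  # "unbind" along axis 1: ndim index vectors
assert (per_dim[0] == [0, 1]).all() and (per_dim[1] == [0, 1]).all()
```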
@@ -2335,6 +2349,21 @@ def _impl(inputs, input_types):
     return _impl


+def _numel():
+    def _impl(inputs, input_types):
+        return _op.ndarray_size(inputs[0])
+
+    return _impl
+
+
+def _empty():
+    def _impl(inputs, input_types):
+        shape = inputs[0]
+        return _op.zeros(shape, _convert_dtype_value(inputs[1]))
+
+    return _impl
+
+
 def _pytorch_result_type(dtypes, non_tensor_inputs):
     """This promotes TVM dtypes like PyTorch would"""
     import torch
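Two notes on the new converters: `_op.ndarray_size` returns the total element count, matching `torch.numel`, and `aten::empty` is lowered to zeros because `torch.empty` leaves its contents unspecified, so any deterministic fill preserves the only observable parts (shape and dtype). In PyTorch terms:

```python
import torch

x = torch.rand(2, 3, 4)
assert x.numel() == 24  # aten::numel -> total element count

e = torch.empty(2, 2)   # contents are uninitialized and unspecified
z = torch.zeros(2, 2)   # a converter may legally substitute a zero fill
assert e.shape == z.shape and e.dtype == z.dtype
```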
@@ -2673,6 +2702,10 @@ def _get_convert_map(prelude, default_dtype):
         "aten::scatter": _scatter(),
         "aten::scalar_tensor": _scalar_tensor(),
         "aten::__interpolate": _interpolate(),
+        "aten::IntImplicit": _identity(),
+        "aten::tensor": _identity(),  # used for example in tensor(1.0)
+        "aten::numel": _numel(),
+        "aten::empty": _empty(),
     }
     return convert_map

@@ -2681,7 +2714,13 @@ def _run_jit_passes(graph):
     """ The inline pass is necessary to unwrap prim::CallMethod """
     import torch

-    torch._C._jit_pass_inline(graph)
+    if _is_version_greater_than("1.5.0"):
+        # This is required for torchvision detection models from 1.6 above
+        # It is the same as _jit_pass_inline, except that it has some special
+        # case behaviors for some ops such as aten::__interpolate()
+        torch._C._jit_pass_onnx_function_substitution(graph)
+    else:
+        torch._C._jit_pass_inline(graph)


 def _get_tensor_and_var(torch_tensor, name):
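A sketch of driving either pass by hand on a traced module, following the same version gate (toy model; `_jit_pass_onnx_function_substitution` is the private pass named in this diff, so treat its availability as version-dependent):

```python
import torch
from packaging import version

model = torch.nn.Sequential(torch.nn.Linear(4, 2)).eval()
traced = torch.jit.trace(model, torch.rand(1, 4))
graph = traced.graph

if version.parse(torch.__version__) > version.parse("1.5.0"):
    torch._C._jit_pass_onnx_function_substitution(graph)
else:
    torch._C._jit_pass_inline(graph)
print(graph)  # prim::CallMethod calls are now inlined into one graph
```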