
Commit fbb130f

masahimasa authored and committed
[Torch] Object detection support update for PyTorch 1.6 (apache#6659)
* update split
* fix
* cast nms output to int64
* add more comment and numel test
* fix lint
* also supported the latest master (1.7)

Co-authored-by: masa <masa@pop-os.localdomain>
1 parent fa65713 commit fbb130f

File tree

3 files changed: +75 −13 lines changed
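For context, the end-to-end workflow this patch targets looks roughly like the sketch below: trace a torchvision detection model and import it with `relay.frontend.from_pytorch`. This is illustrative only, not part of the commit; the `TraceWrapper` class, input name, and input shape are assumptions.

```python
import torch
import torchvision
from tvm import relay

# Detection models return a list of dicts, which torch.jit.trace cannot
# handle directly, so a small wrapper (hypothetical) flattens the outputs.
class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inp):
        out = self.model(inp)
        return out[0]["boxes"], out[0]["scores"], out[0]["labels"]

model = TraceWrapper(torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True))
model.eval()
inp = torch.rand(1, 3, 300, 300)
with torch.no_grad():
    script_module = torch.jit.trace(model, inp)

mod, params = relay.frontend.from_pytorch(script_module, [("input", [1, 3, 300, 300])])
```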

python/tvm/relay/frontend/pytorch.py

Lines changed: 50 additions & 11 deletions
```diff
@@ -46,6 +46,14 @@
 __all__ = ["from_pytorch"]
 
 
+def _is_version_greater_than(ver):
+    import torch
+    from packaging import version
+
+    # Torch version > 1.4 changed upsampling API
+    return version.parse(torch.__version__) > version.parse(ver)
+
+
 # List ADT utilities
 def _infer_type_with_prelude(val, prelude):
     body = _infer_type(val, prelude.mod)
```
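As an aside (not in the diff), `packaging.version` is used here because naive string comparison mishandles multi-digit version components:

```python
from packaging import version

# "1.10.0" is newer than "1.4.0", but lexicographic string comparison
# gets this wrong because it compares "1" < "4" character by character.
assert version.parse("1.10.0") > version.parse("1.4.0")
assert not ("1.10.0" > "1.4.0")
```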
```diff
@@ -413,13 +421,18 @@ def _impl(inputs, input_types):
 def _split_with_sizes():
     def _impl(inputs, input_types):
         data = inputs[0]
+        sections = inputs[1]
         dim = int(inputs[2])
 
+        if len(sections) == 1:
+            # a special case used in torchvision detection models
+            return _expr.TupleWrapper(_expr.Tuple([data]), 1)
+
         split_index = 0
         indices = []
-        sections = inputs[1]
         for i in range(len(sections) - 1):
-            split_index += sections[i]
+            index, _ = try_infer_value(sections[i], lambda ret: int(ret))
+            split_index += index
             indices.append(split_index)
 
         return _op.split(data, indices, dim)
```
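The loop converts PyTorch's per-chunk sizes into the cut points that Relay's `split` expects. A standalone sketch of the arithmetic, with illustrative values:

```python
# torch.split_with_sizes takes chunk sizes; Relay's split takes cut indices.
# Sizes [2, 3, 5] along an axis of length 10 become cut points [2, 5].
sections = [2, 3, 5]
indices, split_index = [], 0
for size in sections[:-1]:  # the last chunk needs no cut point
    split_index += size
    indices.append(split_index)
assert indices == [2, 5]
```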
```diff
@@ -522,6 +535,9 @@ def _impl(inputs, input_types):
 
 def _where():
     def _impl(inputs, input_types):
+        if len(inputs) == 1:
+            return _nonzero(False)([inputs[0], True], input_types)
+
         cond = inputs[0]
         x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3])
         return _op.where(cond, x, y)
```
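The new one-argument branch mirrors PyTorch semantics, where `torch.where(x)` is shorthand for `torch.nonzero(x, as_tuple=True)`. A small PyTorch-side check (not part of the commit):

```python
import torch

x = torch.tensor([0.0, 1.5, 0.0, 2.0])
# Both forms return the indices of the non-zero elements.
assert torch.equal(torch.where(x)[0], torch.nonzero(x, as_tuple=True)[0])
```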
```diff
@@ -1865,11 +1881,8 @@ def func(x):
         return _op.image.resize(x, out_size, "NCHW", method, coord_trans)
 
     if _is_quantized_tensor(data, prelude):
-        import torch
-        from packaging import version
-
         # Torch version > 1.4 changed upsampling API
-        if version.parse(torch.__version__) > version.parse("1.4.0"):
+        if _is_version_greater_than("1.4.0"):
             num_inputs = 7
         else:
             num_inputs = 5
```
```diff
@@ -2172,9 +2185,11 @@ def _impl(inputs, input_types):
         data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])
 
         # strided slice to get the dynamic result
-        return get_relay_op("strided_slice")(
+        ret = get_relay_op("strided_slice")(
             data_slice, begin=_expr.const([0]), end=size, slice_mode="size"
         )
+        # in torchvision, indices from nms are int64
+        return _op.cast(ret, "int64")
 
     return _impl
```
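The cast matches torchvision, whose `nms` returns int64 indices. A minimal check with illustrative boxes and threshold (not in the diff):

```python
import torch
from torchvision.ops import nms

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]])
scores = torch.tensor([0.9, 0.8])
keep = nms(boxes, scores, iou_threshold=0.5)
assert keep.dtype == torch.int64  # the converter now casts its result to match
```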

```diff
@@ -2266,9 +2281,8 @@ def _impl(inputs, input_types):
         ret = _op.transform.argwhere(data)
 
         if is_numpy_style or (len(inputs) > 1 and inputs[1]):
-            # TODO(kevinthesun): Support this by adding unbind op
-            # ret = _unbind()([ret, 0], None)
-            raise RuntimeError("as_tuple is not supported yet for nonzero.")
+            return _unbind()([ret, 1], None)
+
         return ret
 
     return _impl
```
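Unbinding the `(N, ndim)` index matrix along axis 1 reproduces the tuple-of-tensors layout of `as_tuple=True`. A NumPy analogue of the shape manipulation, with an illustrative input:

```python
import numpy as np

x = np.array([[0, 1], [2, 0]])
idx = np.argwhere(x)               # shape (N, ndim), like Relay's argwhere
rows, cols = idx[:, 0], idx[:, 1]  # "unbind" along axis 1
assert rows.tolist() == [0, 1] and cols.tolist() == [1, 0]
```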
```diff
@@ -2335,6 +2349,21 @@ def _impl(inputs, input_types):
     return _impl
 
 
+def _numel():
+    def _impl(inputs, input_types):
+        return _op.ndarray_size(inputs[0])
+
+    return _impl
+
+
+def _empty():
+    def _impl(inputs, input_types):
+        shape = inputs[0]
+        return _op.zeros(shape, _convert_dtype_value(inputs[1]))
+
+    return _impl
+
+
 def _pytorch_result_type(dtypes, non_tensor_inputs):
     """This promotes TVM dtypes like PyTorch would"""
     import torch
```
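`aten::numel` maps directly onto Relay's `ndarray_size`, and `aten::empty` is lowered to `zeros` because PyTorch leaves the tensor's contents unspecified, so any deterministic fill is a valid choice. The PyTorch-side semantics, for reference (not in the diff):

```python
import torch

x = torch.rand(3, 5)
assert torch.numel(x) == 15  # Relay's ndarray_size returns the same scalar

# torch.empty guarantees only shape and dtype, not values, so substituting
# zeros is a safe, deterministic stand-in.
y = torch.empty(3, 5, dtype=torch.float32)
assert y.shape == (3, 5)
```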
```diff
@@ -2673,6 +2702,10 @@ def _get_convert_map(prelude, default_dtype):
         "aten::scatter": _scatter(),
         "aten::scalar_tensor": _scalar_tensor(),
         "aten::__interpolate": _interpolate(),
+        "aten::IntImplicit": _identity(),
+        "aten::tensor": _identity(),  # used for example in tensor(1.0)
+        "aten::numel": _numel(),
+        "aten::empty": _empty(),
     }
     return convert_map
```
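The two `_identity()` entries handle scalar plumbing: in TorchScript, `torch.tensor(scalar)` shows up as `aten::tensor` wrapping an existing value, so passing it through unchanged suffices. A scripted example mirroring the `test_numel` case added below (illustrative, not from the commit):

```python
import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    return torch.tensor(torch.numel(x))

# The graph contains aten::numel feeding aten::tensor; the frontend treats
# the latter as identity over the scalar it wraps.
print(f.graph)
```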

```diff
@@ -2681,7 +2714,13 @@ def _run_jit_passes(graph):
     """ The inline pass is necessary to unwrap prim::CallMethod """
     import torch
 
-    torch._C._jit_pass_inline(graph)
+    if _is_version_greater_than("1.5.0"):
+        # This is required for torchvision detection models from 1.6 above
+        # It is the same as _jit_pass_inline, except that it has some special
+        # case behaviors for some ops such as aten::__interpolate()
+        torch._C._jit_pass_onnx_function_substitution(graph)
+    else:
+        torch._C._jit_pass_inline(graph)
 
 
 def _get_tensor_and_var(torch_tensor, name):
```
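For context, a minimal sketch (not from the commit) of where these JIT passes run: the frontend mutates the TorchScript graph in place before conversion. The pass names are internal `torch._C` APIs referenced in the diff.

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 2)).eval()
traced = torch.jit.trace(model, torch.rand(1, 4))
graph = traced.graph
# Pre-1.6 path; on 1.6+ the diff calls _jit_pass_onnx_function_substitution,
# which also inlines prim::CallMethod but special-cases ops such as
# aten::__interpolate().
torch._C._jit_pass_inline(graph)
print(graph)
```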

python/tvm/relay/op/_tensor.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -179,6 +179,8 @@ def no_data_full_shape_func(attrs, inputs, out_ndims):
     """
     Shape func for zeros and ones.
     """
+    if len(inputs) == 0:
+        return [_convert_shape(convert(attrs.shape))]
     return [_full_shape_func(inputs[0])]
```
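The new branch covers full-shape ops created with a static `shape` attribute and no shape-carrying input, which is exactly what the `aten::empty` → `zeros` lowering above produces. A small Relay check with an illustrative shape:

```python
import tvm
from tvm import relay

# relay.zeros with a constant shape has no runtime inputs, so its shape
# function must read the shape from attrs rather than from inputs[0].
z = relay.zeros(shape=(3, 5), dtype="float32")
mod = tvm.IRModule.from_expr(z)
print(mod)
```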

tests/python/frontend/pytorch/test_forward.py

Lines changed: 23 additions & 2 deletions
```diff
@@ -2865,10 +2865,19 @@ class Where2(Module):
         def forward(self, *args):
             return torch.where(args[0] > 0, args[0], args[1])
 
+    class Where3(Module):
+        def forward(self, *args):
+            return torch.where(args[0])[0]
+
     x = torch.rand([3, 2]).float()
-    verify_model(Where1().float().eval(), input_data=[x])
+    verify_model(Where1(), input_data=[x])
     y = torch.rand([3, 2])
-    verify_model(Where2().float().eval(), input_data=[x, y])
+    verify_model(Where2(), input_data=[x, y])
+
+    # a single argument variant, equivalent to torch.nonzero(..., as_tuple=True)
+    inp = torch.rand([10])
+    inp[3:8] = 0
+    verify_trace_model(Where3(), [inp], ["llvm"])
 
 
 @tvm.testing.uses_gpu
```
```diff
@@ -3152,6 +3161,17 @@ def forward(self, data, index, src):
     verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"])
 
 
+def test_numel():
+    class Numel(Module):
+        def forward(self, data):
+            return torch.tensor(torch.numel(data))
+
+    targets = _get_default_vm_targets()
+    verify_script_model(Numel(), [(1,)], targets)
+    verify_script_model(Numel(), [(3, 5)], targets)
+    verify_script_model(Numel(), [(3, 5, 8)], targets)
+
+
 def test_forward_pretrained_bert_base_uncased():
     ######################################################################
     # This is an example how to run BERT models using TVM
```
```diff
@@ -3455,6 +3475,7 @@ def expected(x_shape, y_shape):
     test_forward_unbind()
     test_forward_nonzero()
     test_forward_scatter()
+    test_numel()
 
     # Model tests
     test_resnet18()
```
