Skip to content

Commit 017334a

Browse files
masamasahi
authored andcommitted
also supported the latest master (1.7)
1 parent 8d9dd2a commit 017334a

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,9 @@ def _impl(inputs, input_types):
535535

536536
def _where():
537537
def _impl(inputs, input_types):
538+
if len(inputs) == 1:
539+
return _nonzero(False)([inputs[0], True], input_types)
540+
538541
cond = inputs[0]
539542
x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3])
540543
return _op.where(cond, x, y)
@@ -2278,9 +2281,8 @@ def _impl(inputs, input_types):
22782281
ret = _op.transform.argwhere(data)
22792282

22802283
if is_numpy_style or (len(inputs) > 1 and inputs[1]):
2281-
# TODO(kevinthesun): Support this by adding unbind op
2282-
# ret = _unbind()([ret, 0], None)
2283-
raise RuntimeError("as_tuple is not supported yet for nonzero.")
2284+
return _unbind()([ret, 1], None)
2285+
22842286
return ret
22852287

22862288
return _impl

tests/python/frontend/pytorch/test_forward.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2865,10 +2865,19 @@ class Where2(Module):
28652865
def forward(self, *args):
28662866
return torch.where(args[0] > 0, args[0], args[1])
28672867

2868+
class Where3(Module):
2869+
def forward(self, *args):
2870+
return torch.where(args[0])[0]
2871+
28682872
x = torch.rand([3, 2]).float()
2869-
verify_model(Where1().float().eval(), input_data=[x])
2873+
verify_model(Where1(), input_data=[x])
28702874
y = torch.rand([3, 2])
2871-
verify_model(Where2().float().eval(), input_data=[x, y])
2875+
verify_model(Where2(), input_data=[x, y])
2876+
2877+
# a single argument variant, equivalent to torch.nonzero(..., as_tuple=True)
2878+
inp = torch.rand([10])
2879+
inp[3:8] = 0
2880+
verify_trace_model(Where3(), [inp], ["llvm"])
28722881

28732882

28742883
@tvm.testing.uses_gpu

0 commit comments

Comments
 (0)