2 files changed: +16 -5 lines

python/tvm/relay/frontend
@@ -535,6 +535,9 @@ def _impl(inputs, input_types):
 def _where():
     def _impl(inputs, input_types):
+        if len(inputs) == 1:
+            return _nonzero(False)([inputs[0], True], input_types)
+
         cond = inputs[0]
         x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3])
         return _op.where(cond, x, y)
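For context, the single-argument path mirrors documented PyTorch behavior: torch.where(x) is shorthand for torch.nonzero(x, as_tuple=True), so the converter simply delegates to the nonzero handler with as_tuple forced on. A minimal sketch of that equivalence in plain PyTorch (the tensor values and variable names are illustrative only):

    import torch

    x = torch.tensor([[0.0, 1.5], [2.0, 0.0]])

    # Single-argument torch.where returns one index tensor per input dimension.
    rows, cols = torch.where(x)

    # Documented equivalent form.
    rows2, cols2 = torch.nonzero(x, as_tuple=True)

    assert torch.equal(rows, rows2) and torch.equal(cols, cols2)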
@@ -2278,9 +2281,8 @@ def _impl(inputs, input_types):
         ret = _op.transform.argwhere(data)

         if is_numpy_style or (len(inputs) > 1 and inputs[1]):
-            # TODO(kevinthesun): Support this by adding unbind op
-            # ret = _unbind()([ret, 0], None)
-            raise RuntimeError("as_tuple is not supported yet for nonzero.")
+            return _unbind()([ret, 1], None)
+
         return ret

     return _impl
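The replacement works because argwhere produces an (N, ndim) matrix of coordinates, while nonzero(..., as_tuple=True) expects one index vector per dimension; unbinding the argwhere result along axis 1 yields exactly that tuple. A NumPy sketch of the same transformation, assuming a 1-D input for brevity:

    import numpy as np

    data = np.array([0.0, 1.0, 0.0, 2.0])

    # argwhere returns an (N, ndim) matrix of non-zero coordinates.
    coords = np.argwhere(data)        # shape (2, 1): [[1], [3]]

    # Splitting ("unbinding") along axis 1 gives one index vector per
    # input dimension -- the layout torch.nonzero(..., as_tuple=True) uses.
    per_dim = tuple(coords[:, i] for i in range(coords.shape[1]))
    print(per_dim)                    # (array([1, 3]),)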
tests/python/frontend/pytorch
@@ -2865,10 +2865,19 @@ class Where2(Module):
         def forward(self, *args):
             return torch.where(args[0] > 0, args[0], args[1])

+    class Where3(Module):
+        def forward(self, *args):
+            return torch.where(args[0])[0]
+
     x = torch.rand([3, 2]).float()
-    verify_model(Where1().float().eval(), input_data=[x])
+    verify_model(Where1(), input_data=[x])

     y = torch.rand([3, 2])
-    verify_model(Where2().float().eval(), input_data=[x, y])
+    verify_model(Where2(), input_data=[x, y])
+
+    # a single argument variant, equivalent to torch.nonzero(..., as_tuple=True)
+    inp = torch.rand([10])
+    inp[3:8] = 0
+    verify_trace_model(Where3(), [inp], ["llvm"])


 @tvm.testing.uses_gpu
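The new test zeroes out part of a random vector so the result has a data-dependent shape; Where3 picks the first (and only) index tensor from the tuple. The tracing-based verify_trace_model helper is presumably used because the output shape depends on the input values. A hedged illustration of the expected behavior (torch.rand values are almost surely non-zero, so only the explicitly zeroed slots drop out):

    import torch

    inp = torch.rand([10])
    inp[3:8] = 0

    # Indices of non-zero entries; slots 3..7 were zeroed out above.
    idx = torch.where(inp)[0]
    print(idx)  # tensor([0, 1, 2, 8, 9]) in the typical case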