Skip to content

Commit b1364eb

Browse files
authored
[PYTORCH]Take, Topk op support (#5332)
* [PYTORCH] take, topk op support * CI failure fix
1 parent afcf939 commit b1364eb

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,39 @@ def _impl(inputs, input_types):
272272
return _op.transform.take(data, index, axis=dim)
273273
return _impl
274274

275+
def _take():
    """Convert aten::take to Relay take (flattened-index gather)."""
    def _impl(inputs, input_types):
        # torch is imported lazily so the frontend module loads without it.
        import torch

        data = inputs[0]
        index_arg = inputs[1]

        if isinstance(index_arg, _expr.Var):
            # Relay take requires integer indices; cast the graph input.
            indices = _op.cast(index_arg, "int32")
        elif isinstance(index_arg, torch.Tensor):
            # Constant index tensor baked into the traced graph.
            indices = _wrap_const(index_arg.numpy())
        else:
            msg = "Data type %s could not be parsed in take operator." % (type(index_arg))
            raise AssertionError(msg)

        return _op.transform.take(data, indices=indices)
    return _impl
290+
291+
def _topk():
    """Convert aten::topk to Relay topk, returning (values, indices)."""
    def _impl(inputs, input_types):
        data = inputs[0]
        k = int(inputs[1])
        axis = int(inputs[2])
        # PyTorch `largest=True` means descending order, i.e. is_ascend=False.
        is_ascend = not bool(inputs[3])
        sort = bool(inputs[4])

        # Relay topk always emits sorted output; unsorted is unsupported.
        if not sort:
            msg = "Currently supports only sorted output for topk operator."
            raise AssertionError(msg)

        outs = _op.topk(
            data,
            k=k,
            axis=axis,
            is_ascend=is_ascend,
            ret_type="both",
        )
        return outs[0], outs[1]
    return _impl
307+
275308
def _reciprocal():
276309
def _impl(inputs, input_types):
277310
data = inputs[0]
@@ -1416,6 +1449,8 @@ def _get_convert_map(prelude):
14161449
"aten::split" : _split(),
14171450
"aten::split_with_sizes" : _split_with_sizes(),
14181451
"aten::select" : _select(),
1452+
"aten::take" : _take(),
1453+
"aten::topk" : _topk(),
14191454
"aten::relu" : _relu(),
14201455
"aten::relu_" : _relu(),
14211456
"aten::prelu" : _prelu(),

tests/python/frontend/pytorch/test_forward.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,6 +1545,61 @@ def forward(self, *args):
15451545
verify_model(Round1().float().eval(), input_data=input_data)
15461546

15471547

1548+
def test_forward_take():
    """Exercise aten::take with a constant index tensor and a passed-in one."""
    torch.set_grad_enabled(False)

    class Take1(Module):
        def forward(self, *args):
            # Index tensor created inside the graph (constant after tracing).
            indices = torch.tensor([[0, 0], [1, 0]])
            if torch.cuda.is_available():
                indices = indices.cuda()
            return torch.take(args[0], indices)

    class Take2(Module):
        def forward(self, *args):
            # Index tensor supplied as a second model input.
            return torch.take(args[0], args[1])

    input_data = torch.tensor([[1, 2], [3, 4]])
    indices = torch.tensor([[0, 0], [1, 0]])
    verify_model(Take1().float().eval(), input_data=input_data)
    verify_model(Take2().float().eval(), input_data=[input_data, indices])
1565+
1566+
1567+
def test_forward_topk():
    """Exercise aten::topk across dim/largest/sorted argument combinations."""
    torch.set_grad_enabled(False)

    class Topk1(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3)

    class Topk2(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, dim=-2)

    class Topk3(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, dim=3)

    class Topk4(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, largest=True)

    class Topk5(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, largest=False)

    class Topk6(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, sorted=True)

    # All variants share the same random 4-D input.
    input_data = torch.rand([1, 3, 10, 10]).float()
    for model_cls in (Topk1, Topk2, Topk3, Topk4, Topk5, Topk6):
        verify_model(model_cls().float().eval(), input_data=input_data)
1601+
1602+
15481603
if __name__ == "__main__":
15491604
# Single operator tests
15501605
test_forward_add()
@@ -1587,6 +1642,8 @@ def forward(self, *args):
15871642
test_forward_size()
15881643
test_forward_view()
15891644
test_forward_select()
1645+
test_forward_take()
1646+
test_forward_topk()
15901647
test_forward_clone()
15911648
test_forward_softplus()
15921649
test_forward_softsign()

0 commit comments

Comments
 (0)