Skip to content

Commit 5266afc

Browse files
siju-samueltrevor-m
authored andcommitted
[PYTORCH]Unary Ops (apache#5378)
1 parent 8c8b5b6 commit 5266afc

File tree

2 files changed

+114
-123
lines changed

2 files changed

+114
-123
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 26 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,16 @@ def _impl(inputs, input_types):
132132
return get_relay_op(name)(data0, data1)
133133
return _impl
134134

135-
def _abs():
135+
136+
def _unary(name):
136137
def _impl(inputs, input_types):
137-
data = inputs[0]
138-
return _op.abs(data)
138+
input_type = input_types[0]
139+
data = _convert_elemwise_input(inputs[0], input_type)
140+
141+
return get_relay_op(name)(data)
139142
return _impl
140143

144+
141145
def _arange():
142146
def _impl(inputs, input_types):
143147
if len(inputs) == 5:
@@ -1260,26 +1264,6 @@ def _impl(inputs, input_types):
12601264
return _op.nn.pad(data, pad_width, pad_value)
12611265
return _impl
12621266

1263-
def _sqrt():
1264-
def _impl(inputs, input_types):
1265-
data = inputs[0]
1266-
return _op.tensor.sqrt(data)
1267-
return _impl
1268-
1269-
1270-
def _rsqrt():
1271-
def _impl(inputs, input_types):
1272-
data = inputs[0]
1273-
return _op.tensor.rsqrt(data)
1274-
return _impl
1275-
1276-
1277-
def _ceil():
1278-
def _impl(inputs, input_types):
1279-
data = inputs[0]
1280-
return _op.ceil(data)
1281-
return _impl
1282-
12831267

12841268
def _clamp():
12851269
def _impl(inputs, input_types):
@@ -1290,20 +1274,6 @@ def _impl(inputs, input_types):
12901274
return _impl
12911275

12921276

1293-
def _floor():
1294-
def _impl(inputs, input_types):
1295-
data = inputs[0]
1296-
return _op.floor(data)
1297-
return _impl
1298-
1299-
1300-
def _round():
1301-
def _impl(inputs, input_types):
1302-
data = inputs[0]
1303-
return _op.round(data)
1304-
return _impl
1305-
1306-
13071277
def _to():
13081278
def _impl(inputs, input_types):
13091279
data = inputs[0]
@@ -1381,17 +1351,6 @@ def _impl(inputs, input_types):
13811351
return inputs[0]
13821352
return _impl
13831353

1384-
def _neg():
1385-
def _impl(inputs, input_types):
1386-
data = inputs[0]
1387-
return _op.tensor.negative(data)
1388-
return _impl
1389-
1390-
def _tanh():
1391-
def _impl(inputs, input_types):
1392-
data = inputs[0]
1393-
return _op.tensor.tanh(data)
1394-
return _impl
13951354

13961355
def _Bool():
13971356
def _impl(inputs, input_types):
@@ -1473,18 +1432,6 @@ def _impl(inputs, input_types):
14731432
return _impl
14741433

14751434

1476-
def _isfinite():
1477-
def _impl(inputs, input_types):
1478-
return _op.isfinite(inputs[0])
1479-
return _impl
1480-
1481-
1482-
def _isnan():
1483-
def _impl(inputs, input_types):
1484-
return _op.isnan(inputs[0])
1485-
return _impl
1486-
1487-
14881435
def _list_getitem(prelude):
14891436
def _impl(inputs, input_types):
14901437
return prelude.nth(inputs[0], _wrap_const(inputs[1]))
@@ -1607,7 +1554,6 @@ def _get_convert_map(prelude):
16071554
"aten::mul" : _elemwise("multiply"),
16081555
"aten::mul_" : _elemwise("multiply"),
16091556
"aten::pow" : _elemwise("power"),
1610-
"aten::abs" : _abs(),
16111557
"aten::arange" : _arange(),
16121558
"aten::div" : _elemwise("divide"),
16131559
"aten::div_" : _elemwise("divide"),
@@ -1689,12 +1635,26 @@ def _get_convert_map(prelude):
16891635
"aten::argmax" : _reduce("argmax"),
16901636
"aten::std" : _std(),
16911637
"aten::var" : _variance(),
1692-
"aten::sqrt" : _sqrt(),
1693-
"aten::rsqrt" : _rsqrt(),
1694-
"aten::ceil" : _ceil(),
1638+
"aten::abs" : _unary("abs"),
1639+
"aten::neg" : _unary("negative"),
1640+
"aten::cos" : _unary("cos"),
1641+
"aten::sin" : _unary("sin"),
1642+
"aten::tan" : _unary("tan"),
1643+
"aten::tanh" : _unary("tanh"),
1644+
"aten::atan" : _unary("atan"),
1645+
"aten::log" : _unary("log"),
1646+
"aten::exp" : _unary("exp"),
1647+
"aten::erf" : _unary("erf"),
1648+
"aten::trunc" : _unary("trunc"),
1649+
"aten::sign" : _unary("sign"),
1650+
"aten::sqrt" : _unary("sqrt"),
1651+
"aten::rsqrt" : _unary("rsqrt"),
1652+
"aten::ceil" : _unary("ceil"),
1653+
"aten::floor" : _unary("floor"),
1654+
"aten::round" : _unary("round"),
1655+
"aten::isfinite" : _unary("isfinite"),
1656+
"aten::isnan" : _unary("isnan"),
16951657
"aten::clamp" : _clamp(),
1696-
"aten::floor" : _floor(),
1697-
"aten::round" : _round(),
16981658
"aten::detach" : _identity(),
16991659
"aten::upsample_bilinear2d" : _upsample("bilinear"),
17001660
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
@@ -1709,12 +1669,8 @@ def _get_convert_map(prelude):
17091669
"aten::logical_xor" : _logical_xor(),
17101670
"aten::bitwise_not" : _bitwise_not(),
17111671
"aten::bitwise_xor" : _bitwise_xor(),
1712-
"aten::isfinite" : _isfinite(),
1713-
"aten::isnan" : _isnan(),
17141672
"aten::Bool" : _Bool(),
17151673
"aten::Float" : _Float(),
1716-
"aten::neg" : _neg(),
1717-
"aten::tanh" : _tanh(),
17181674
"aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(),
17191675
"aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(),
17201676
"aten::mm" : _matmul(),

tests/python/frontend/pytorch/test_forward.py

Lines changed: 88 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,30 +1508,6 @@ def forward(self, *args):
15081508
verify_model(IsInf1().float().eval(), input_data=input_data)
15091509

15101510

1511-
def test_forward_rsqrt():
1512-
torch.set_grad_enabled(False)
1513-
input_shape = [1, 3, 10, 10]
1514-
1515-
class Rsqrt1(Module):
1516-
def forward(self, *args):
1517-
return torch.rsqrt(args[0])
1518-
1519-
input_data = torch.rand(input_shape).float()
1520-
verify_model(Rsqrt1().float().eval(), input_data=input_data)
1521-
1522-
1523-
def test_forward_ceil():
1524-
torch.set_grad_enabled(False)
1525-
input_shape = [1, 3, 10, 10]
1526-
1527-
class Ceil1(Module):
1528-
def forward(self, *args):
1529-
return torch.ceil(args[0])
1530-
1531-
input_data = torch.rand(input_shape).float()
1532-
verify_model(Ceil1().float().eval(), input_data=input_data)
1533-
1534-
15351511
def test_forward_clamp():
15361512
torch.set_grad_enabled(False)
15371513
input_shape = [1, 3, 10, 10]
@@ -1554,30 +1530,6 @@ def forward(self, *args):
15541530
verify_model(Clamp3().float().eval(), input_data=input_data)
15551531

15561532

1557-
def test_forward_floor():
1558-
torch.set_grad_enabled(False)
1559-
input_shape = [1, 3, 10, 10]
1560-
1561-
class Floor1(Module):
1562-
def forward(self, *args):
1563-
return torch.floor(args[0])
1564-
1565-
input_data = torch.rand(input_shape).float()
1566-
verify_model(Floor1().float().eval(), input_data=input_data)
1567-
1568-
1569-
def test_forward_round():
1570-
torch.set_grad_enabled(False)
1571-
input_shape = [1, 3, 10, 10]
1572-
1573-
class Round1(Module):
1574-
def forward(self, *args):
1575-
return torch.round(args[0])
1576-
1577-
input_data = torch.rand(input_shape).float()
1578-
verify_model(Round1().float().eval(), input_data=input_data)
1579-
1580-
15811533
def test_forward_ones():
15821534
torch.set_grad_enabled(False)
15831535

@@ -1860,6 +1812,93 @@ def forward(self, *args):
18601812
verify_model(LogicalXor2().float().eval(), input_data=[lhs])
18611813

18621814

1815+
def test_forward_unary():
1816+
torch.set_grad_enabled(False)
1817+
1818+
class Sqrt1(Module):
1819+
def forward(self, *args):
1820+
return torch.sqrt(args[0])
1821+
1822+
class RSqrt1(Module):
1823+
def forward(self, *args):
1824+
return torch.rsqrt(args[0])
1825+
1826+
class Ceil1(Module):
1827+
def forward(self, *args):
1828+
return torch.ceil(args[0])
1829+
1830+
class Floor1(Module):
1831+
def forward(self, *args):
1832+
return torch.floor(args[0])
1833+
1834+
class Round1(Module):
1835+
def forward(self, *args):
1836+
return torch.round(args[0])
1837+
1838+
class Cos1(Module):
1839+
def forward(self, *args):
1840+
return torch.cos(args[0])
1841+
1842+
class Sin1(Module):
1843+
def forward(self, *args):
1844+
return torch.sin(args[0])
1845+
1846+
class Tan1(Module):
1847+
def forward(self, *args):
1848+
return torch.tan(args[0])
1849+
1850+
class Tanh1(Module):
1851+
def forward(self, *args):
1852+
return torch.tanh(args[0])
1853+
1854+
class ATanh1(Module):
1855+
def forward(self, *args):
1856+
return torch.atan(args[0])
1857+
1858+
class Log1(Module):
1859+
def forward(self, *args):
1860+
return torch.log(args[0])
1861+
1862+
class Exp1(Module):
1863+
def forward(self, *args):
1864+
return torch.exp(args[0])
1865+
1866+
class Erf1(Module):
1867+
def forward(self, *args):
1868+
return torch.erf(args[0])
1869+
1870+
class Trunc1(Module):
1871+
def forward(self, *args):
1872+
return torch.trunc(args[0])
1873+
1874+
class Sign1(Module):
1875+
def forward(self, *args):
1876+
return torch.sign(args[0])
1877+
1878+
class Neg1(Module):
1879+
def forward(self, *args):
1880+
return torch.neg(args[0])
1881+
1882+
input_shape = [1, 3, 10, 10]
1883+
input_data = torch.rand(input_shape).float()
1884+
verify_model(Sqrt1().float().eval(), input_data=input_data)
1885+
verify_model(RSqrt1().float().eval(), input_data=input_data)
1886+
verify_model(Ceil1().float().eval(), input_data=input_data)
1887+
verify_model(Floor1().float().eval(), input_data=input_data)
1888+
verify_model(Round1().float().eval(), input_data=input_data)
1889+
verify_model(Cos1().float().eval(), input_data=input_data)
1890+
verify_model(Sin1().float().eval(), input_data=input_data)
1891+
verify_model(Tan1().float().eval(), input_data=input_data)
1892+
verify_model(Tanh1().float().eval(), input_data=input_data)
1893+
verify_model(ATanh1().float().eval(), input_data=input_data)
1894+
verify_model(Log1().float().eval(), input_data=input_data)
1895+
verify_model(Exp1().float().eval(), input_data=input_data)
1896+
verify_model(Erf1().float().eval(), input_data=input_data)
1897+
verify_model(Trunc1().float().eval(), input_data=input_data)
1898+
verify_model(Sign1().float().eval(), input_data=input_data)
1899+
verify_model(Neg1().float().eval(), input_data=input_data)
1900+
1901+
18631902
if __name__ == "__main__":
18641903
# Single operator tests
18651904
test_forward_add()
@@ -1918,12 +1957,8 @@ def forward(self, *args):
19181957
test_forward_mean()
19191958
test_forward_expand()
19201959
test_forward_pow()
1921-
test_forward_abs()
1922-
test_forward_rsqrt()
1923-
test_forward_ceil()
1960+
test_forward_unary()
19241961
test_forward_clamp()
1925-
test_forward_floor()
1926-
test_forward_round()
19271962
test_forward_logical_not()
19281963
test_forward_bitwise_not()
19291964
test_forward_bitwise_xor()

0 commit comments

Comments
 (0)