Skip to content

Commit 22db299

Browse files
authored
[PYTORCH]Unary Ops (#5378)
1 parent c3511c5 commit 22db299

File tree

2 files changed

+114
-123
lines changed

2 files changed

+114
-123
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 26 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,16 @@ def _impl(inputs, input_types):
132132
return get_relay_op(name)(data0, data1)
133133
return _impl
134134

135-
def _abs():
135+
136+
def _unary(name):
136137
def _impl(inputs, input_types):
137-
data = inputs[0]
138-
return _op.abs(data)
138+
input_type = input_types[0]
139+
data = _convert_elemwise_input(inputs[0], input_type)
140+
141+
return get_relay_op(name)(data)
139142
return _impl
140143

144+
141145
def _arange():
142146
def _impl(inputs, input_types):
143147
if len(inputs) == 5:
@@ -1254,26 +1258,6 @@ def _impl(inputs, input_types):
12541258
return _op.nn.pad(data, pad_width, pad_value)
12551259
return _impl
12561260

1257-
def _sqrt():
1258-
def _impl(inputs, input_types):
1259-
data = inputs[0]
1260-
return _op.tensor.sqrt(data)
1261-
return _impl
1262-
1263-
1264-
def _rsqrt():
1265-
def _impl(inputs, input_types):
1266-
data = inputs[0]
1267-
return _op.tensor.rsqrt(data)
1268-
return _impl
1269-
1270-
1271-
def _ceil():
1272-
def _impl(inputs, input_types):
1273-
data = inputs[0]
1274-
return _op.ceil(data)
1275-
return _impl
1276-
12771261

12781262
def _clamp():
12791263
def _impl(inputs, input_types):
@@ -1284,20 +1268,6 @@ def _impl(inputs, input_types):
12841268
return _impl
12851269

12861270

1287-
def _floor():
1288-
def _impl(inputs, input_types):
1289-
data = inputs[0]
1290-
return _op.floor(data)
1291-
return _impl
1292-
1293-
1294-
def _round():
1295-
def _impl(inputs, input_types):
1296-
data = inputs[0]
1297-
return _op.round(data)
1298-
return _impl
1299-
1300-
13011271
def _to():
13021272
def _impl(inputs, input_types):
13031273
data = inputs[0]
@@ -1375,17 +1345,6 @@ def _impl(inputs, input_types):
13751345
return inputs[0]
13761346
return _impl
13771347

1378-
def _neg():
1379-
def _impl(inputs, input_types):
1380-
data = inputs[0]
1381-
return _op.tensor.negative(data)
1382-
return _impl
1383-
1384-
def _tanh():
1385-
def _impl(inputs, input_types):
1386-
data = inputs[0]
1387-
return _op.tensor.tanh(data)
1388-
return _impl
13891348

13901349
def _Bool():
13911350
def _impl(inputs, input_types):
@@ -1467,18 +1426,6 @@ def _impl(inputs, input_types):
14671426
return _impl
14681427

14691428

1470-
def _isfinite():
1471-
def _impl(inputs, input_types):
1472-
return _op.isfinite(inputs[0])
1473-
return _impl
1474-
1475-
1476-
def _isnan():
1477-
def _impl(inputs, input_types):
1478-
return _op.isnan(inputs[0])
1479-
return _impl
1480-
1481-
14821429
def _list_getitem(prelude):
14831430
def _impl(inputs, input_types):
14841431
return prelude.nth(inputs[0], _wrap_const(inputs[1]))
@@ -1601,7 +1548,6 @@ def _get_convert_map(prelude):
16011548
"aten::mul" : _elemwise("multiply"),
16021549
"aten::mul_" : _elemwise("multiply"),
16031550
"aten::pow" : _elemwise("power"),
1604-
"aten::abs" : _abs(),
16051551
"aten::arange" : _arange(),
16061552
"aten::div" : _elemwise("divide"),
16071553
"aten::div_" : _elemwise("divide"),
@@ -1683,12 +1629,26 @@ def _get_convert_map(prelude):
16831629
"aten::argmax" : _reduce("argmax"),
16841630
"aten::std" : _std(),
16851631
"aten::var" : _variance(),
1686-
"aten::sqrt" : _sqrt(),
1687-
"aten::rsqrt" : _rsqrt(),
1688-
"aten::ceil" : _ceil(),
1632+
"aten::abs" : _unary("abs"),
1633+
"aten::neg" : _unary("negative"),
1634+
"aten::cos" : _unary("cos"),
1635+
"aten::sin" : _unary("sin"),
1636+
"aten::tan" : _unary("tan"),
1637+
"aten::tanh" : _unary("tanh"),
1638+
"aten::atan" : _unary("atan"),
1639+
"aten::log" : _unary("log"),
1640+
"aten::exp" : _unary("exp"),
1641+
"aten::erf" : _unary("erf"),
1642+
"aten::trunc" : _unary("trunc"),
1643+
"aten::sign" : _unary("sign"),
1644+
"aten::sqrt" : _unary("sqrt"),
1645+
"aten::rsqrt" : _unary("rsqrt"),
1646+
"aten::ceil" : _unary("ceil"),
1647+
"aten::floor" : _unary("floor"),
1648+
"aten::round" : _unary("round"),
1649+
"aten::isfinite" : _unary("isfinite"),
1650+
"aten::isnan" : _unary("isnan"),
16891651
"aten::clamp" : _clamp(),
1690-
"aten::floor" : _floor(),
1691-
"aten::round" : _round(),
16921652
"aten::detach" : _identity(),
16931653
"aten::upsample_bilinear2d" : _upsample("bilinear"),
16941654
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
@@ -1703,12 +1663,8 @@ def _get_convert_map(prelude):
17031663
"aten::logical_xor" : _logical_xor(),
17041664
"aten::bitwise_not" : _bitwise_not(),
17051665
"aten::bitwise_xor" : _bitwise_xor(),
1706-
"aten::isfinite" : _isfinite(),
1707-
"aten::isnan" : _isnan(),
17081666
"aten::Bool" : _Bool(),
17091667
"aten::Float" : _Float(),
1710-
"aten::neg" : _neg(),
1711-
"aten::tanh" : _tanh(),
17121668
"aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(),
17131669
"aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(),
17141670
"aten::mm" : _matmul(),

tests/python/frontend/pytorch/test_forward.py

Lines changed: 88 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,30 +1497,6 @@ def forward(self, *args):
14971497
verify_model(IsInf1().float().eval(), input_data=input_data)
14981498

14991499

1500-
def test_forward_rsqrt():
1501-
torch.set_grad_enabled(False)
1502-
input_shape = [1, 3, 10, 10]
1503-
1504-
class Rsqrt1(Module):
1505-
def forward(self, *args):
1506-
return torch.rsqrt(args[0])
1507-
1508-
input_data = torch.rand(input_shape).float()
1509-
verify_model(Rsqrt1().float().eval(), input_data=input_data)
1510-
1511-
1512-
def test_forward_ceil():
1513-
torch.set_grad_enabled(False)
1514-
input_shape = [1, 3, 10, 10]
1515-
1516-
class Ceil1(Module):
1517-
def forward(self, *args):
1518-
return torch.ceil(args[0])
1519-
1520-
input_data = torch.rand(input_shape).float()
1521-
verify_model(Ceil1().float().eval(), input_data=input_data)
1522-
1523-
15241500
def test_forward_clamp():
15251501
torch.set_grad_enabled(False)
15261502
input_shape = [1, 3, 10, 10]
@@ -1543,30 +1519,6 @@ def forward(self, *args):
15431519
verify_model(Clamp3().float().eval(), input_data=input_data)
15441520

15451521

1546-
def test_forward_floor():
1547-
torch.set_grad_enabled(False)
1548-
input_shape = [1, 3, 10, 10]
1549-
1550-
class Floor1(Module):
1551-
def forward(self, *args):
1552-
return torch.floor(args[0])
1553-
1554-
input_data = torch.rand(input_shape).float()
1555-
verify_model(Floor1().float().eval(), input_data=input_data)
1556-
1557-
1558-
def test_forward_round():
1559-
torch.set_grad_enabled(False)
1560-
input_shape = [1, 3, 10, 10]
1561-
1562-
class Round1(Module):
1563-
def forward(self, *args):
1564-
return torch.round(args[0])
1565-
1566-
input_data = torch.rand(input_shape).float()
1567-
verify_model(Round1().float().eval(), input_data=input_data)
1568-
1569-
15701522
def test_forward_ones():
15711523
torch.set_grad_enabled(False)
15721524

@@ -1849,6 +1801,93 @@ def forward(self, *args):
18491801
verify_model(LogicalXor2().float().eval(), input_data=[lhs])
18501802

18511803

1804+
def test_forward_unary():
1805+
torch.set_grad_enabled(False)
1806+
1807+
class Sqrt1(Module):
1808+
def forward(self, *args):
1809+
return torch.sqrt(args[0])
1810+
1811+
class RSqrt1(Module):
1812+
def forward(self, *args):
1813+
return torch.rsqrt(args[0])
1814+
1815+
class Ceil1(Module):
1816+
def forward(self, *args):
1817+
return torch.ceil(args[0])
1818+
1819+
class Floor1(Module):
1820+
def forward(self, *args):
1821+
return torch.floor(args[0])
1822+
1823+
class Round1(Module):
1824+
def forward(self, *args):
1825+
return torch.round(args[0])
1826+
1827+
class Cos1(Module):
1828+
def forward(self, *args):
1829+
return torch.cos(args[0])
1830+
1831+
class Sin1(Module):
1832+
def forward(self, *args):
1833+
return torch.sin(args[0])
1834+
1835+
class Tan1(Module):
1836+
def forward(self, *args):
1837+
return torch.tan(args[0])
1838+
1839+
class Tanh1(Module):
1840+
def forward(self, *args):
1841+
return torch.tanh(args[0])
1842+
1843+
class ATanh1(Module):
1844+
def forward(self, *args):
1845+
return torch.atan(args[0])
1846+
1847+
class Log1(Module):
1848+
def forward(self, *args):
1849+
return torch.log(args[0])
1850+
1851+
class Exp1(Module):
1852+
def forward(self, *args):
1853+
return torch.exp(args[0])
1854+
1855+
class Erf1(Module):
1856+
def forward(self, *args):
1857+
return torch.erf(args[0])
1858+
1859+
class Trunc1(Module):
1860+
def forward(self, *args):
1861+
return torch.trunc(args[0])
1862+
1863+
class Sign1(Module):
1864+
def forward(self, *args):
1865+
return torch.sign(args[0])
1866+
1867+
class Neg1(Module):
1868+
def forward(self, *args):
1869+
return torch.neg(args[0])
1870+
1871+
input_shape = [1, 3, 10, 10]
1872+
input_data = torch.rand(input_shape).float()
1873+
verify_model(Sqrt1().float().eval(), input_data=input_data)
1874+
verify_model(RSqrt1().float().eval(), input_data=input_data)
1875+
verify_model(Ceil1().float().eval(), input_data=input_data)
1876+
verify_model(Floor1().float().eval(), input_data=input_data)
1877+
verify_model(Round1().float().eval(), input_data=input_data)
1878+
verify_model(Cos1().float().eval(), input_data=input_data)
1879+
verify_model(Sin1().float().eval(), input_data=input_data)
1880+
verify_model(Tan1().float().eval(), input_data=input_data)
1881+
verify_model(Tanh1().float().eval(), input_data=input_data)
1882+
verify_model(ATanh1().float().eval(), input_data=input_data)
1883+
verify_model(Log1().float().eval(), input_data=input_data)
1884+
verify_model(Exp1().float().eval(), input_data=input_data)
1885+
verify_model(Erf1().float().eval(), input_data=input_data)
1886+
verify_model(Trunc1().float().eval(), input_data=input_data)
1887+
verify_model(Sign1().float().eval(), input_data=input_data)
1888+
verify_model(Neg1().float().eval(), input_data=input_data)
1889+
1890+
18521891
if __name__ == "__main__":
18531892
# Single operator tests
18541893
test_forward_add()
@@ -1907,12 +1946,8 @@ def forward(self, *args):
19071946
test_forward_mean()
19081947
test_forward_expand()
19091948
test_forward_pow()
1910-
test_forward_abs()
1911-
test_forward_rsqrt()
1912-
test_forward_ceil()
1949+
test_forward_unary()
19131950
test_forward_clamp()
1914-
test_forward_floor()
1915-
test_forward_round()
19161951
test_forward_logical_not()
19171952
test_forward_bitwise_not()
19181953
test_forward_bitwise_xor()

0 commit comments

Comments
 (0)