@@ -1600,6 +1600,95 @@ def forward(self, *args):
     verify_model(Topk6().float().eval(), input_data=input_data)
 
 
+def test_forward_logical_not():
+    torch.set_grad_enabled(False)
+
+    class LogicalNot1(Module):
+        def forward(self, *args):
+            return torch.logical_not(args[0])
+
+    input_data = torch.tensor([True, False])
+    verify_model(LogicalNot1().float().eval(), input_data=input_data)
+
+    input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
+    verify_model(LogicalNot1().float().eval(), input_data=input_data)
+
+    input_data = torch.tensor([0., 1.5, -10.], dtype=torch.double)
+    verify_model(LogicalNot1().float().eval(), input_data=input_data)
+
+    input_data = torch.tensor([0., 1., -10.], dtype=torch.int32)
+    verify_model(LogicalNot1().float().eval(), input_data=input_data)
+
+
+def test_forward_bitwise_not():
+    torch.set_grad_enabled(False)
+
+    class BitwiseNot1(Module):
+        def forward(self, *args):
+            return torch.bitwise_not(args[0])
+
+    input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
+    verify_model(BitwiseNot1().float().eval(), input_data=input_data)
+
+    input_data = torch.tensor([0., 1., -10.], dtype=torch.int32)
+    verify_model(BitwiseNot1().float().eval(), input_data=input_data)
+
+    input_data = torch.tensor([True, False])
+    verify_model(BitwiseNot1().float().eval(), input_data=input_data)
+
+
+def test_forward_bitwise_xor():
+    torch.set_grad_enabled(False)
+
+    class BitwiseXor1(Module):
+        def forward(self, *args):
+            return torch.bitwise_xor(args[0], args[1])
+
+    class BitwiseXor2(Module):
+        def forward(self, *args):
+            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+            if torch.cuda.is_available():
+                rhs = rhs.cuda()
+            return torch.bitwise_xor(args[0], rhs)
+
+    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+    rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+    verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])
+
+    lhs = torch.tensor([True, True, False])
+    rhs = torch.tensor([False, True, False])
+    verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])
+
+    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+    verify_model(BitwiseXor2().float().eval(), input_data=[lhs])
+
+
+def test_forward_logical_xor():
+    torch.set_grad_enabled(False)
+
+    class LogicalXor1(Module):
+        def forward(self, *args):
+            return torch.logical_xor(args[0], args[1])
+
+    class LogicalXor2(Module):
+        def forward(self, *args):
+            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+            if torch.cuda.is_available():
+                rhs = rhs.cuda()
+            return torch.logical_xor(args[0], rhs)
+
+    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+    rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+    verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])
+
+    lhs = torch.tensor([True, True, False])
+    rhs = torch.tensor([False, True, False])
+    verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])
+
+    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+    verify_model(LogicalXor2().float().eval(), input_data=[lhs])
+
+
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
@@ -1663,6 +1752,10 @@ def forward(self, *args):
     test_forward_clamp()
     test_forward_floor()
     test_forward_round()
+    test_forward_logical_not()
+    test_forward_bitwise_not()
+    test_forward_bitwise_xor()
+    test_forward_logical_xor()
     test_forward_isfinite()
     test_forward_isnan()
     test_forward_isinf()