@@ -1220,51 +1220,68 @@ def aten_binomial(
12201220@torch_op (
12211221 (
12221222 "aten::bitwise_and.Tensor" ,
1223- "aten::bitwise_and.Scalar" ,
1224- "aten::bitwise_and.Scalar_Tensor" ,
12251223 "_operator::and_" ,
12261224 ),
12271225 trace_only = True ,
12281226)
12291227def aten_bitwise_and (self : TTensor , other : TTensor ) -> TTensor :
12301228 """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor"""
12311229
1232- assert self .dtype == other .dtype
1230+ assert self .dtype == other .dtype or self .dtype is None or other .dtype is None
1231+ dtype = self .dtype if self .dtype is not None else other .dtype
1232+ assert dtype is not None
12331233
1234- if self . dtype .is_integer ():
1234+ if dtype .is_integer ():
12351235 return op .BitwiseAnd (self , other )
1236- if self . dtype == ir .DataType .BOOL :
1236+ if dtype == ir .DataType .BOOL :
12371237 return op .And (self , other )
12381238 raise NotImplementedError (f"Not implemented for types { self .dtype } and { other .dtype } " )
12391239
12401240
1241+ @torch_op ("aten::bitwise_and.Scalar" , trace_only = True )
1242+ def aten_bitwise_and_scalar (self : TTensor , other : int ) -> TTensor :
1243+ """bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor"""
1244+
1245+ other_tensor = op .Constant (value = ir .tensor (other , dtype = self .dtype ))
1246+ return aten_bitwise_and (self , other_tensor )
1247+
1248+
1249+ @torch_op ("aten::bitwise_and.Scalar_Tensor" , trace_only = True )
1250+ def aten_bitwise_and_scalar_tensor (self : float , other : TTensor ) -> TTensor :
1251+ """bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
1252+
1253+ self_tensor = op .Constant (value = ir .tensor (self , dtype = other .dtype ))
1254+ return aten_bitwise_and (self_tensor , other )
1255+
1256+
12411257@torch_op (
12421258 (
12431259 "aten::bitwise_left_shift.Tensor" ,
1244- "aten::bitwise_left_shift.Tensor_Scalar" ,
1245- "aten::bitwise_left_shift.Scalar_Tensor" ,
12461260 "_operator::__lshift__" ,
1247- "aten::__lshift__.Scalar" ,
12481261 ),
12491262 trace_only = True ,
12501263)
12511264def aten_bitwise_left_shift (self : TInt , other : TInt ) -> TInt :
12521265 """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
1266+ assert self .dtype == other .dtype or self .dtype is None or other .dtype is None
1267+ dtype = self .dtype if self .dtype is not None else other .dtype
1268+ assert dtype is not None
1269+
12531270 # assert other >= 0
1254- if self . dtype .bitwidth == 8 :
1271+ if dtype .bitwidth == 8 :
12551272 unsigned_dtype = ir .DataType .UINT8
12561273 signed_dtype = ir .DataType .INT8
1257- elif self . dtype .bitwidth == 16 :
1274+ elif dtype .bitwidth == 16 :
12581275 unsigned_dtype = ir .DataType .UINT16
12591276 signed_dtype = ir .DataType .INT16
1260- elif self . dtype .bitwidth == 32 :
1277+ elif dtype .bitwidth == 32 :
12611278 unsigned_dtype = ir .DataType .UINT32
12621279 signed_dtype = ir .DataType .INT32
1263- elif self . dtype .bitwidth == 64 :
1280+ elif dtype .bitwidth == 64 :
12641281 unsigned_dtype = ir .DataType .UINT64
12651282 signed_dtype = ir .DataType .INT64
12661283 else :
1267- raise NotImplementedError (f"Not implemented for type { self . dtype } " )
1284+ raise NotImplementedError (f"Not implemented for type { dtype } " )
12681285
12691286 self = op .Cast (self , to = unsigned_dtype )
12701287 other = op .Cast (other , to = unsigned_dtype )
@@ -1274,6 +1291,22 @@ def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt:
12741291 return op .Cast (result , to = signed_dtype )
12751292
12761293
1294+ @torch_op (
1295+ ("aten::bitwise_left_shift.Tensor_Scalar" , "aten::__lshift__.Scalar" ), trace_only = True
1296+ )
1297+ def aten_bitwise_left_shift_tensor_scalar (self : TInt , other : int ) -> TInt :
1298+ """bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor"""
1299+ other_tensor = op .Constant (value = ir .tensor (other , dtype = self .dtype ))
1300+ return aten_bitwise_left_shift (self , other_tensor )
1301+
1302+
1303+ @torch_op ("aten::bitwise_left_shift.Scalar_Tensor" , trace_only = True )
1304+ def aten_bitwise_left_shift_scalar_tensor (self : int , other : TInt ) -> TInt :
1305+ """bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
1306+ self_tensor = op .Constant (value = ir .tensor (self , dtype = other .dtype ))
1307+ return aten_bitwise_left_shift (self_tensor , other )
1308+
1309+
12771310@torch_op ("aten::bitwise_not" , trace_only = True )
12781311def aten_bitwise_not (self : TTensor ) -> TTensor :
12791312 """bitwise_not(Tensor self) -> Tensor"""
@@ -1288,54 +1321,69 @@ def aten_bitwise_not(self: TTensor) -> TTensor:
12881321@torch_op (
12891322 (
12901323 "aten::bitwise_or.Tensor" ,
1291- "aten::bitwise_or.Scalar" ,
1292- "aten::bitwise_or.Scalar_Tensor" ,
12931324 "_operator::or_" ,
12941325 ),
12951326 trace_only = True ,
12961327)
12971328def aten_bitwise_or (self : TTensor , other : TTensor ) -> TTensor :
12981329 """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor"""
12991330
1300- assert self .dtype == other .dtype
1331+ assert self .dtype == other .dtype or self .dtype is None or other .dtype is None
1332+ dtype = self .dtype if self .dtype is not None else other .dtype
1333+ assert dtype is not None
13011334
1302- if self . dtype .is_integer ():
1335+ if dtype .is_integer ():
13031336 return op .BitwiseOr (self , other )
1304- if self . dtype == ir .DataType .BOOL :
1337+ if dtype == ir .DataType .BOOL :
13051338 return op .Or (self , other )
13061339 raise NotImplementedError (f"Not implemented for types { self .dtype } and { other .dtype } " )
13071340
13081341
1342+ @torch_op ("aten::bitwise_or.Scalar" , trace_only = True )
1343+ def aten_bitwise_or_scalar (self : TTensor , other : int ) -> TTensor :
1344+ """bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor"""
1345+ other_tensor = op .Constant (value = ir .tensor (other , dtype = self .dtype ))
1346+ return aten_bitwise_or (self , other_tensor )
1347+
1348+
1349+ @torch_op ("aten::bitwise_or.Scalar_Tensor" , trace_only = True )
1350+ def aten_bitwise_or_scalar_tensor (self : int , other : TTensor ) -> TTensor :
1351+ """bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
1352+ self_tensor = op .Constant (value = ir .tensor (self , dtype = other .dtype ))
1353+ return aten_bitwise_or (self_tensor , other )
1354+
1355+
13091356@torch_op (
13101357 (
13111358 "aten::bitwise_right_shift.Tensor" ,
1312- "aten::bitwise_right_shift.Tensor_Scalar" ,
1313- "aten::bitwise_right_shift.Scalar_Tensor" ,
13141359 "_operator::__rshift__" ,
1315- "aten::__rshift__.Scalar" ,
13161360 ),
13171361 trace_only = True ,
13181362)
13191363def aten_bitwise_right_shift (self : TInt , other : TInt ) -> TInt :
13201364 """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
1321- if self .dtype .bitwidth == 8 :
1365+ assert self .dtype == other .dtype or self .dtype is None or other .dtype is None
1366+ dtype = self .dtype if self .dtype is not None else other .dtype
1367+ assert dtype is not None
1368+
1369+ if dtype .bitwidth == 8 :
13221370 unsigned_dtype = ir .DataType .UINT8
13231371 signed_dtype = ir .DataType .INT8
13241372 mask = ir .tensor (0xFF , dtype = unsigned_dtype )
1325- elif self . dtype .bitwidth == 16 :
1373+ elif dtype .bitwidth == 16 :
13261374 unsigned_dtype = ir .DataType .UINT16
13271375 signed_dtype = ir .DataType .INT16
13281376 mask = ir .tensor (0xFFFF , dtype = unsigned_dtype )
1329- elif self . dtype .bitwidth == 32 :
1377+ elif dtype .bitwidth == 32 :
13301378 unsigned_dtype = ir .DataType .UINT32
13311379 signed_dtype = ir .DataType .INT32
13321380 mask = ir .tensor (0xFFFFFFFF , dtype = unsigned_dtype )
1333- elif self . dtype .bitwidth == 64 :
1381+ elif dtype .bitwidth == 64 :
13341382 unsigned_dtype = ir .DataType .UINT64
13351383 signed_dtype = ir .DataType .INT64
13361384 mask = ir .tensor (0xFFFFFFFFFFFFFFFF , dtype = unsigned_dtype ) # 0xFFFFFFFFFFFFFFFF
13371385 else :
1338- raise NotImplementedError (f"Not implemented for type { self . dtype } " )
1386+ raise NotImplementedError (f"Not implemented for type { dtype } " )
13391387
13401388 negative = op .Less (self , 0 )
13411389 self = op .Cast (self , to = unsigned_dtype )
@@ -1356,24 +1404,50 @@ def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt:
13561404
13571405
13581406@torch_op (
1359- (
1360- "aten::bitwise_xor.Tensor" ,
1361- "aten::bitwise_xor.Scalar" ,
1362- "aten::bitwise_xor.Scalar_Tensor" ,
1363- ),
1364- trace_only = True ,
1407+ ("aten::bitwise_right_shift.Tensor_Scalar" , "aten::__rshift__.Scalar" ), trace_only = True
13651408)
1409+ def aten_bitwise_right_shift_tensor_scalar (self : TInt , other : int ) -> TInt :
1410+ """bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor"""
1411+ other_tensor = op .Constant (value = ir .tensor (other , dtype = self .dtype ))
1412+ return aten_bitwise_right_shift (self , other_tensor )
1413+
1414+
1415+ @torch_op ("aten::bitwise_right_shift.Scalar_Tensor" , trace_only = True )
1416+ def aten_bitwise_right_shift_scalar_tensor (self : int , other : TInt ) -> TInt :
1417+ """bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
1418+ self_tensor = op .Constant (value = ir .tensor (self , dtype = other .dtype ))
1419+ return aten_bitwise_right_shift (self_tensor , other )
1420+
1421+
1422+ @torch_op ("aten::bitwise_xor.Tensor" , trace_only = True )
13661423def aten_bitwise_xor (self : TTensor , other : TTensor ) -> TTensor :
13671424 """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor"""
1368- assert self .dtype == other .dtype
13691425
1370- if self .dtype .is_integer ():
1426+ assert self .dtype == other .dtype or self .dtype is None or other .dtype is None
1427+ dtype = self .dtype if self .dtype is not None else other .dtype
1428+ assert dtype is not None
1429+
1430+ if dtype .is_integer ():
13711431 return op .BitwiseXor (self , other )
1372- if self . dtype == ir .DataType .BOOL :
1432+ if dtype == ir .DataType .BOOL :
13731433 return op .Xor (self , other )
13741434 raise NotImplementedError (f"Not implemented for types { self .dtype } and { other .dtype } " )
13751435
13761436
1437+ @torch_op ("aten::bitwise_xor.Scalar" , trace_only = True )
1438+ def aten_bitwise_xor_scalar (self : TTensor , other : int ) -> TTensor :
1439+ """bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor"""
1440+ other_tensor = op .Constant (value = ir .tensor (other , dtype = self .dtype ))
1441+ return aten_bitwise_xor (self , other_tensor )
1442+
1443+
1444+ @torch_op ("aten::bitwise_xor.Scalar_Tensor" , trace_only = True )
1445+ def aten_bitwise_xor_scalar_tensor (self : int , other : TTensor ) -> TTensor :
1446+ """bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
1447+ self_tensor = op .Constant (value = ir .tensor (self , dtype = other .dtype ))
1448+ return aten_bitwise_xor (self_tensor , other )
1449+
1450+
13771451@torch_op ("aten::blackman_window" , trace_only = True )
13781452def aten_blackman_window (
13791453 window_length : int ,
0 commit comments