Skip to content

Commit 59c3d32

Browse files
justinchuby and Copilot authored
[torchlib] Fix implementations for bitwise_* overloads (#2618)
Some overloads for bitwise_* can accept scalar inputs which do not have the dtype. This PR creates implementations for the overloads. Fix #2617 --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent cb6f873 commit 59c3d32

File tree

2 files changed

+122
-35
lines changed

2 files changed

+122
-35
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 109 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,51 +1220,68 @@ def aten_binomial(
12201220
@torch_op(
12211221
(
12221222
"aten::bitwise_and.Tensor",
1223-
"aten::bitwise_and.Scalar",
1224-
"aten::bitwise_and.Scalar_Tensor",
12251223
"_operator::and_",
12261224
),
12271225
trace_only=True,
12281226
)
12291227
def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor:
12301228
"""bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor"""
12311229

1232-
assert self.dtype == other.dtype
1230+
assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
1231+
dtype = self.dtype if self.dtype is not None else other.dtype
1232+
assert dtype is not None
12331233

1234-
if self.dtype.is_integer():
1234+
if dtype.is_integer():
12351235
return op.BitwiseAnd(self, other)
1236-
if self.dtype == ir.DataType.BOOL:
1236+
if dtype == ir.DataType.BOOL:
12371237
return op.And(self, other)
12381238
raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")
12391239

12401240

1241+
@torch_op("aten::bitwise_and.Scalar", trace_only=True)
1242+
def aten_bitwise_and_scalar(self: TTensor, other: int) -> TTensor:
1243+
"""bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor"""
1244+
1245+
other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
1246+
return aten_bitwise_and(self, other_tensor)
1247+
1248+
1249+
@torch_op("aten::bitwise_and.Scalar_Tensor", trace_only=True)
1250+
def aten_bitwise_and_scalar_tensor(self: float, other: TTensor) -> TTensor:
1251+
"""bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
1252+
1253+
self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
1254+
return aten_bitwise_and(self_tensor, other)
1255+
1256+
12411257
@torch_op(
12421258
(
12431259
"aten::bitwise_left_shift.Tensor",
1244-
"aten::bitwise_left_shift.Tensor_Scalar",
1245-
"aten::bitwise_left_shift.Scalar_Tensor",
12461260
"_operator::__lshift__",
1247-
"aten::__lshift__.Scalar",
12481261
),
12491262
trace_only=True,
12501263
)
12511264
def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt:
12521265
"""bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
1266+
assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
1267+
dtype = self.dtype if self.dtype is not None else other.dtype
1268+
assert dtype is not None
1269+
12531270
# assert other >= 0
1254-
if self.dtype.bitwidth == 8:
1271+
if dtype.bitwidth == 8:
12551272
unsigned_dtype = ir.DataType.UINT8
12561273
signed_dtype = ir.DataType.INT8
1257-
elif self.dtype.bitwidth == 16:
1274+
elif dtype.bitwidth == 16:
12581275
unsigned_dtype = ir.DataType.UINT16
12591276
signed_dtype = ir.DataType.INT16
1260-
elif self.dtype.bitwidth == 32:
1277+
elif dtype.bitwidth == 32:
12611278
unsigned_dtype = ir.DataType.UINT32
12621279
signed_dtype = ir.DataType.INT32
1263-
elif self.dtype.bitwidth == 64:
1280+
elif dtype.bitwidth == 64:
12641281
unsigned_dtype = ir.DataType.UINT64
12651282
signed_dtype = ir.DataType.INT64
12661283
else:
1267-
raise NotImplementedError(f"Not implemented for type {self.dtype}")
1284+
raise NotImplementedError(f"Not implemented for type {dtype}")
12681285

12691286
self = op.Cast(self, to=unsigned_dtype)
12701287
other = op.Cast(other, to=unsigned_dtype)
@@ -1274,6 +1291,22 @@ def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt:
12741291
return op.Cast(result, to=signed_dtype)
12751292

12761293

1294+
@torch_op(
1295+
("aten::bitwise_left_shift.Tensor_Scalar", "aten::__lshift__.Scalar"), trace_only=True
1296+
)
1297+
def aten_bitwise_left_shift_tensor_scalar(self: TInt, other: int) -> TInt:
1298+
"""bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor"""
1299+
other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
1300+
return aten_bitwise_left_shift(self, other_tensor)
1301+
1302+
1303+
@torch_op("aten::bitwise_left_shift.Scalar_Tensor", trace_only=True)
1304+
def aten_bitwise_left_shift_scalar_tensor(self: int, other: TInt) -> TInt:
1305+
"""bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
1306+
self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
1307+
return aten_bitwise_left_shift(self_tensor, other)
1308+
1309+
12771310
@torch_op("aten::bitwise_not", trace_only=True)
12781311
def aten_bitwise_not(self: TTensor) -> TTensor:
12791312
"""bitwise_not(Tensor self) -> Tensor"""
@@ -1288,54 +1321,69 @@ def aten_bitwise_not(self: TTensor) -> TTensor:
12881321
@torch_op(
12891322
(
12901323
"aten::bitwise_or.Tensor",
1291-
"aten::bitwise_or.Scalar",
1292-
"aten::bitwise_or.Scalar_Tensor",
12931324
"_operator::or_",
12941325
),
12951326
trace_only=True,
12961327
)
12971328
def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor:
12981329
"""bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor"""
12991330

1300-
assert self.dtype == other.dtype
1331+
assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
1332+
dtype = self.dtype if self.dtype is not None else other.dtype
1333+
assert dtype is not None
13011334

1302-
if self.dtype.is_integer():
1335+
if dtype.is_integer():
13031336
return op.BitwiseOr(self, other)
1304-
if self.dtype == ir.DataType.BOOL:
1337+
if dtype == ir.DataType.BOOL:
13051338
return op.Or(self, other)
13061339
raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")
13071340

13081341

1342+
@torch_op("aten::bitwise_or.Scalar", trace_only=True)
1343+
def aten_bitwise_or_scalar(self: TTensor, other: int) -> TTensor:
1344+
"""bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor"""
1345+
other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
1346+
return aten_bitwise_or(self, other_tensor)
1347+
1348+
1349+
@torch_op("aten::bitwise_or.Scalar_Tensor", trace_only=True)
1350+
def aten_bitwise_or_scalar_tensor(self: int, other: TTensor) -> TTensor:
1351+
"""bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
1352+
self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
1353+
return aten_bitwise_or(self_tensor, other)
1354+
1355+
13091356
@torch_op(
13101357
(
13111358
"aten::bitwise_right_shift.Tensor",
1312-
"aten::bitwise_right_shift.Tensor_Scalar",
1313-
"aten::bitwise_right_shift.Scalar_Tensor",
13141359
"_operator::__rshift__",
1315-
"aten::__rshift__.Scalar",
13161360
),
13171361
trace_only=True,
13181362
)
13191363
def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt:
13201364
"""bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
1321-
if self.dtype.bitwidth == 8:
1365+
assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
1366+
dtype = self.dtype if self.dtype is not None else other.dtype
1367+
assert dtype is not None
1368+
1369+
if dtype.bitwidth == 8:
13221370
unsigned_dtype = ir.DataType.UINT8
13231371
signed_dtype = ir.DataType.INT8
13241372
mask = ir.tensor(0xFF, dtype=unsigned_dtype)
1325-
elif self.dtype.bitwidth == 16:
1373+
elif dtype.bitwidth == 16:
13261374
unsigned_dtype = ir.DataType.UINT16
13271375
signed_dtype = ir.DataType.INT16
13281376
mask = ir.tensor(0xFFFF, dtype=unsigned_dtype)
1329-
elif self.dtype.bitwidth == 32:
1377+
elif dtype.bitwidth == 32:
13301378
unsigned_dtype = ir.DataType.UINT32
13311379
signed_dtype = ir.DataType.INT32
13321380
mask = ir.tensor(0xFFFFFFFF, dtype=unsigned_dtype)
1333-
elif self.dtype.bitwidth == 64:
1381+
elif dtype.bitwidth == 64:
13341382
unsigned_dtype = ir.DataType.UINT64
13351383
signed_dtype = ir.DataType.INT64
13361384
mask = ir.tensor(0xFFFFFFFFFFFFFFFF, dtype=unsigned_dtype) # 0xFFFFFFFFFFFFFFFF
13371385
else:
1338-
raise NotImplementedError(f"Not implemented for type {self.dtype}")
1386+
raise NotImplementedError(f"Not implemented for type {dtype}")
13391387

13401388
negative = op.Less(self, 0)
13411389
self = op.Cast(self, to=unsigned_dtype)
@@ -1356,24 +1404,50 @@ def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt:
13561404

13571405

13581406
@torch_op(
1359-
(
1360-
"aten::bitwise_xor.Tensor",
1361-
"aten::bitwise_xor.Scalar",
1362-
"aten::bitwise_xor.Scalar_Tensor",
1363-
),
1364-
trace_only=True,
1407+
("aten::bitwise_right_shift.Tensor_Scalar", "aten::__rshift__.Scalar"), trace_only=True
13651408
)
1409+
def aten_bitwise_right_shift_tensor_scalar(self: TInt, other: int) -> TInt:
1410+
"""bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor"""
1411+
other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
1412+
return aten_bitwise_right_shift(self, other_tensor)
1413+
1414+
1415+
@torch_op("aten::bitwise_right_shift.Scalar_Tensor", trace_only=True)
1416+
def aten_bitwise_right_shift_scalar_tensor(self: int, other: TInt) -> TInt:
1417+
"""bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
1418+
self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
1419+
return aten_bitwise_right_shift(self_tensor, other)
1420+
1421+
1422+
@torch_op("aten::bitwise_xor.Tensor", trace_only=True)
13661423
def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor:
13671424
"""bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor"""
1368-
assert self.dtype == other.dtype
13691425

1370-
if self.dtype.is_integer():
1426+
assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
1427+
dtype = self.dtype if self.dtype is not None else other.dtype
1428+
assert dtype is not None
1429+
1430+
if dtype.is_integer():
13711431
return op.BitwiseXor(self, other)
1372-
if self.dtype == ir.DataType.BOOL:
1432+
if dtype == ir.DataType.BOOL:
13731433
return op.Xor(self, other)
13741434
raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")
13751435

13761436

1437+
@torch_op("aten::bitwise_xor.Scalar", trace_only=True)
1438+
def aten_bitwise_xor_scalar(self: TTensor, other: int) -> TTensor:
1439+
"""bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor"""
1440+
other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
1441+
return aten_bitwise_xor(self, other_tensor)
1442+
1443+
1444+
@torch_op("aten::bitwise_xor.Scalar_Tensor", trace_only=True)
1445+
def aten_bitwise_xor_scalar_tensor(self: int, other: TTensor) -> TTensor:
1446+
"""bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
1447+
self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
1448+
return aten_bitwise_xor(self_tensor, other)
1449+
1450+
13771451
@torch_op("aten::blackman_window", trace_only=True)
13781452
def aten_blackman_window(
13791453
window_length: int,

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,19 @@ def forward(self, q, k, v):
225225
)
226226
_testing.assert_onnx_program(onnx_program)
227227

228+
def test_bitwise_and_scalar(self):
229+
class Model(torch.nn.Module):
230+
def forward(self, x):
231+
return x & 3
232+
233+
onnx_program = torch.onnx.export(
234+
Model(),
235+
(torch.tensor([1, 2, 3, 4, 5]),),
236+
dynamo=True,
237+
verbose=False,
238+
)
239+
_testing.assert_onnx_program(onnx_program)
240+
228241

229242
if __name__ == "__main__":
230243
unittest.main()

0 commit comments

Comments (0)