Skip to content

Commit 3a814f6

Browse files
committed
Fix out of bounds integer behavior in full
1 parent 7c1d147 commit 3a814f6

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

dpctl/tensor/_ctors.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,8 @@ def full(
11011101
fill_value = int(fill_value.real)
11021102
elif fill_value_type is complex and np.issubdtype(dtype, np.floating):
11031103
fill_value = fill_value.real
1104+
elif fill_value_type is int and np.issubdtype(dtype, np.integer):
1105+
fill_value = _to_scalar(fill_value, dtype)
11041106

11051107
hev, _ = ti._full_usm_ndarray(fill_value, res, sycl_queue)
11061108
hev.wait()

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,6 +1477,19 @@ def test_full_strides():
14771477
assert np.array_equal(dpt.asnumpy(X), Xnp)
14781478

14791479

1480+
def test_full_gh_1230():
1481+
q = get_queue_or_skip()
1482+
dtype = "i4"
1483+
dt_maxint = dpt.iinfo(dtype).max
1484+
X = dpt.full(1, dt_maxint + 1, dtype=dtype, sycl_queue=q)
1485+
X_np = dpt.asnumpy(X)
1486+
assert X.dtype == dpt.dtype(dtype)
1487+
assert np.array_equal(X_np, np.full_like(X_np, dt_maxint + 1))
1488+
1489+
with pytest.raises(OverflowError):
1490+
dpt.full(1, dpt.iinfo(dpt.uint64).max + 1, sycl_queue=q)
1491+
1492+
14801493
@pytest.mark.parametrize(
14811494
"dt",
14821495
_all_dtypes[1:],

0 commit comments

Comments
 (0)