Skip to content

Commit e7e8508

Browse files
committed
Fixed error in cast dtype for full() function.
1 parent 05c358e commit e7e8508

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

dpctl/tensor/_ctors.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -773,14 +773,20 @@ def full(
773773

774774
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
775775
usm_type = usm_type if usm_type is not None else "device"
776-
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
776+
fill_value_type = type(fill_value)
777+
dtype = _get_dtype(dtype, sycl_queue, ref_type=fill_value_type)
777778
res = dpt.usm_ndarray(
778779
sh,
779780
dtype=dtype,
780781
buffer=usm_type,
781782
order=order,
782783
buffer_ctor_kwargs={"queue": sycl_queue},
783784
)
785+
if fill_value_type in [float, complex] and np.issubdtype(dtype, np.integer):
786+
fill_value = int(fill_value.real)
787+
elif fill_value_type is complex and np.issubdtype(dtype, np.floating):
788+
fill_value = fill_value.real
789+
784790
hev, _ = ti._full_usm_ndarray(fill_value, res, sycl_queue)
785791
hev.wait()
786792
return res

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,10 @@ def test_full_dtype_inference():
991991
assert np.issubdtype(dpt.full(10, 12.3).dtype, np.floating)
992992
assert np.issubdtype(dpt.full(10, 0.3 - 2j).dtype, np.complexfloating)
993993

994+
assert np.issubdtype(dpt.full(10, 12.3, dtype=int).dtype, np.integer)
995+
assert np.issubdtype(dpt.full(10, 0.3 - 2j, dtype=int).dtype, np.integer)
996+
assert np.issubdtype(dpt.full(10, 0.3 - 2j, dtype=float).dtype, np.floating)
997+
994998

995999
def test_full_fill_array():
9961000
q = get_queue_or_skip()

0 commit comments

Comments
 (0)