Skip to content

Commit 89f8fe9

Browse files
Tweak to find_buf_dtype2
If both input types must be promoted outside of their kind, use default device data type of the kind of the result array data type. E.g. divide( int8_array, bool_array ) must return float32/float64 depending on the device capabilities, not float16.
1 parent 7453bf7 commit 89f8fe9

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

dpctl/tensor/_type_utils.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import builtins
1818

1919
import dpctl.tensor as dpt
20+
import dpctl.tensor._tensor_impl as ti
2021

2122

2223
def _all_data_types(_fp16, _fp64):
@@ -237,6 +238,20 @@ def _find_buf_dtype(arg_dtype, query_fn, sycl_dev):
237238
return None, None
238239

239240

241+
def _get_device_default_dtype(dt_kind, sycl_dev):
242+
if dt_kind == "b":
243+
return dpt.dtype(ti.default_device_bool_type(sycl_dev))
244+
elif dt_kind == "i":
245+
return dpt.dtype(ti.default_device_int_type(sycl_dev))
246+
elif dt_kind == "u":
247+
return dpt.dtype(ti.default_device_int_type(sycl_dev).upper())
248+
elif dt_kind == "f":
249+
return dpt.dtype(ti.default_device_fp_type(sycl_dev))
250+
elif dt_kind == "c":
251+
return dpt.dtype(ti.default_device_complex_type(sycl_dev))
252+
raise RuntimeError
253+
254+
240255
def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
241256
res_dt = query_fn(arg1_dtype, arg2_dtype)
242257
if res_dt:
@@ -254,7 +269,24 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
254269
if res_dt:
255270
ret_buf1_dt = None if buf1_dt == arg1_dtype else buf1_dt
256271
ret_buf2_dt = None if buf2_dt == arg2_dtype else buf2_dt
257-
return ret_buf1_dt, ret_buf2_dt, res_dt
272+
if ret_buf1_dt is None or ret_buf2_dt is None:
273+
return ret_buf1_dt, ret_buf2_dt, res_dt
274+
else:
275+
# both are being promoted, if the kind of result is
276+
# different than the kind of original input dtypes,
277+
# we must use default dtype for the resulting kind.
278+
if (res_dt.kind != arg1_dtype.kind) and (
279+
res_dt.kind != arg2_dtype.kind
280+
):
281+
default_dt = _get_device_default_dtype(
282+
res_dt.kind, sycl_dev
283+
)
284+
if res_dt == default_dt:
285+
return ret_buf1_dt, ret_buf2_dt, res_dt
286+
else:
287+
continue
288+
else:
289+
return ret_buf1_dt, ret_buf2_dt, res_dt
258290

259291
return None, None, None
260292

0 commit comments

Comments
 (0)