Closed
Description
For some reason, the USM type of the array returned by dpt.extract is not coerced from the USM types of its inputs:
import dpctl, dpctl.tensor as dpt
dpctl.__version__
# Out: '0.18.0dev0+77.g26d34f565d'
a = dpt.ones(2, usm_type='shared')
cond = dpt.asarray([True, False], usm_type='device')
dpt.extract(cond, a).usm_type
# Out: 'shared'
dpctl.utils.get_coerced_usm_type([a.usm_type, cond.usm_type])
# Out: 'device'
Other dpctl tensor functions would return a result with USM type 'device' for the same
inputs as in the example above, so dpt.extract is expected to do the same.
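To make the expected behaviour concrete, here is a minimal pure-Python sketch of the coercion rule, assuming the precedence is 'device' over 'shared' over 'host' (consistent with the get_coerced_usm_type output shown above). The helper name coerce_usm_type_sketch is hypothetical and not part of dpctl; it only illustrates the result that dpt.extract would be expected to report.

```python
# Hypothetical helper, NOT part of dpctl. It mirrors the assumed precedence
# 'device' > 'shared' > 'host' used when combining USM types of several inputs.
_USM_PRECEDENCE = {"host": 0, "shared": 1, "device": 2}


def coerce_usm_type_sketch(usm_types):
    """Return the highest-precedence USM type among the given inputs."""
    usm_types = list(usm_types)
    if not usm_types:
        raise ValueError("at least one USM type is required")
    for t in usm_types:
        if t not in _USM_PRECEDENCE:
            raise ValueError(f"unrecognized USM type: {t!r}")
    # Pick the type with the largest precedence rank.
    return max(usm_types, key=_USM_PRECEDENCE.__getitem__)


# Same combination as in the report: a is 'shared', cond is 'device'.
print(coerce_usm_type_sketch(["shared", "device"]))
```

Under this assumption, the call with ['shared', 'device'] yields 'device', which is the USM type the extract result would be expected to carry instead of 'shared'.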