Closed
Labels: bug (Something isn't working)
Description
If the shape of the condition argument differs from the shapes of both x and y, so that the condition has to be broadcast, dpctl.tensor.where returns an incorrect result.
For example:
cond_shape = (4,), x_shape = (4,), y_shape = (4,) - correct
cond_shape = (2, 3), x_shape = (2, 3), y_shape = (3,) - correct
cond_shape = (4,), x_shape = (3, 4), y_shape = (3, 4) - incorrect (reproducer below)
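Under the standard broadcasting rules, the condition does not need to match the shape of x or y; all three arrays only need to broadcast to a common shape. The sketch below uses NumPy's np.broadcast_shapes to compute that common shape for each listed case. It assumes the third case uses the (3, 4) arrays from the reproducer, since (4,) and (2, 3) are not broadcast-compatible at all.

```python
import numpy as np

# (cond_shape, x_shape, y_shape) triples from the cases above; the third case
# uses the (3, 4) arrays from the reproducer (an assumption, see lead-in).
cases = [
    ((4,), (4,), (4,)),
    ((2, 3), (2, 3), (3,)),
    ((4,), (3, 4), (3, 4)),
]

out_shapes = []
for cond_shape, x_shape, y_shape in cases:
    # Applies the standard broadcasting rules; raises ValueError on mismatch.
    out_shape = np.broadcast_shapes(cond_shape, x_shape, y_shape)
    out_shapes.append(out_shape)
    print(cond_shape, x_shape, y_shape, "->", out_shape)

# (4,) and (2, 3) cannot be broadcast together: trailing dims 4 and 3 differ.
try:
    np.broadcast_shapes((4,), (2, 3))
    incompatible = False
except ValueError:
    incompatible = True
```

Only the third case requires expanding the condition to a shape it does not already have (from (4,) to (3, 4)), which matches the case reported as incorrect.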
import numpy as np
import dpctl.tensor as dpt

# NumPy
np_cond = [False, True, True, False]
np_x = np.ones((3, 4))
np_y = np.zeros((3, 4))

# dpctl.tensor
dpt_cond = dpt.asarray(np_cond)
dpt_x = dpt.asarray(np_x)
dpt_y = dpt.asarray(np_y)

np.where(np_cond, np_x, np_y)
>> array([[0., 1., 1., 0.],
          [0., 1., 1., 0.],
          [0., 1., 1., 0.]])

dpt.where(dpt_cond, dpt_x, dpt_y)
>> usm_ndarray([[0., 0., 0., 1.],
                [1., 1., 1., 1.],
                [1., 0., 0., 0.]], dtype=float32)
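The reported output is not random: each condition element appears three times in a row when the result is read in C order. The NumPy-only sketch below checks the hypothesis (not a confirmed root cause) that the condition was expanded along the wrong axis, i.e. treated as a (4, 1) column broadcast to (4, 3) and then read as (3, 4). The names misbroadcast and hypothesis are illustrative only.

```python
import numpy as np

cond = np.array([False, True, True, False])
x = np.ones((3, 4))
y = np.zeros((3, 4))

# Correct semantics: cond of shape (4,) is broadcast along the leading axis.
expected = np.where(cond, x, y)

# Output reported for dpt.where in this issue.
reported = np.array([[0., 0., 0., 1.],
                     [1., 1., 1., 1.],
                     [1., 0., 0., 0.]])

# Hypothesis: cond broadcast as a (4, 1) column to (4, 3), then read as (3, 4).
misbroadcast = np.broadcast_to(cond[:, None], (4, 3)).reshape(3, 4)
hypothesis = np.where(misbroadcast, x, y)

matches_expected = np.array_equal(expected, reported)
matches_hypothesis = np.array_equal(hypothesis, reported)
print(matches_expected, matches_hypothesis)
```

If the hypothesis holds, a regression test could compare dpt.asnumpy(dpt.where(dpt_cond, dpt_x, dpt_y)) against np.where(np_cond, np_x, np_y) for condition shapes that need broadcasting.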