dpctl.tensor.where returns wrong results with different shapes  #1170

@vlad-perevezentsev

Description

If the shape of the condition parameter does not match the shape of either x or y, dpctl.tensor.where returns an incorrect result.
For example:
cond_shape = (4,), x_shape = (4,), y_shape = (4,) - correct
cond_shape = (2, 3), x_shape = (2, 3), y_shape = (3,) - correct
cond_shape = (4,), x_shape = (3, 4), y_shape = (3, 4) - incorrect

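import numpy as np
import dpctl.tensor as dpt

# NumPy inputs: a 1-D condition and 2-D x and y
np_cond = [False, True, True, False]
np_x = np.ones((3,4))
np_y = np.zeros((3,4))

# dpctl.tensor inputs built from the same data
dpt_cond = dpt.asarray(np_cond)
dpt_x = dpt.asarray(np_x)
dpt_y = dpt.asarray(np_y)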

np.where(np_cond, np_x, np_y)
>> array([[0., 1., 1., 0.],
          [0., 1., 1., 0.],
          [0., 1., 1., 0.]])

dpt.where(dpt_cond, dpt_x, dpt_y)
>> usm_ndarray([[0., 0., 0., 1.],
                [1., 1., 1., 1.],
                [1., 0., 0., 0.]], dtype=float32)

Labels: bug (Something isn't working)
