-
Notifications
You must be signed in to change notification settings - Fork 32
Fixes incorrect output in dpctl.tensor.where strided implementation #1171
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
- Implemented FourOffsets_StridedIndexer and FourOffsets - Implemented simplify_iteration_space_4 and simplify_iteration_four_strides - Made std::min calls in strided_iters more readable - Added test case for #1170
287846e to
c4cb96b
Compare
- Matches other contract_iter functions
|
View rendered docs @ https://intelpython.github.io/dpctl/pulls/1171/index.html |
|
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_108 ran successfully. |
c4cb96b to
98c6651
Compare
|
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_108 ran successfully. |
|
I have added a test: def test_where_invariants():
test_sh = (6, 8,)
mask = dpt.asarray(np.random.choice([True, False], size=test_sh))
p = dpt.ones(test_sh, dtype=dpt.int16)
m = dpt.full(test_sh, -1, dtype=dpt.int16)
inds_list = [(np.s_[:3], np.s_[::2],), (np.s_[::2], np.s_[::2],), (np.s_[::-1], np.s_[:],),]
for ind in inds_list:
r1 = dpt.where(mask, p, m)[ind]
r2 = dpt.where(mask[ind], p[ind], m[ind])
assert (dpt.asnumpy(r1) == dpt.asnumpy(r2)).all() |
| py::ssize_t &, | ||
| py::ssize_t &); | ||
|
|
||
| void simplify_iteration_space_4(int &, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One day, it would be nice to figure out how to make this nicer, perhaps by packing data about src1, etc. into structs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And of course, this same comment applies to simplify_iteration_space_3 and others.
|
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_109 ran successfully. |
|
Deleted rendered PR docs from intelpython.github.com/dpctl, latest should be updated shortly. 🤞 |
|
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_109 ran successfully. |
This PR implements simplify_iteration_space_4 and indexers for offsets in 4 arrays to resolve #1170.
Additionally, implements contract_iter4 and simplifies calls to std::min in strided_iters.
The issue was caused by an oversight in the mapping of elements between arrays in the strided where kernel. As the destination array was not being simplified, there was not a guarantee that the nth element of the destination corresponded to the correct element in x1 or x2.