Implements dpctl.tensor.where #1147
View rendered docs @ https://intelpython.github.io/dpctl/pulls/1147/index.html |
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_59 ran successfully. |
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_58 ran successfully. |
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_67 ran successfully. |
Significant performance gains from re-implementation:
Where, for a larger array (on the scale of 10^5 elements)
|
An issue in type promotion logic:
import dpctl.tensor as dpt
m = dpt.zeros(100, dtype="?")
m[::2] = True; m[1::3] = True; m[2::5] = False
x1 = dpt.ones(100, dtype="i4")
x2 = dpt.zeros(100, dtype="f4")
dpt.where(m, x1, x2)  # raises ValueError
The error is |
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_68 ran successfully. |
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_69 ran successfully. |
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_74 ran successfully. |
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_79 ran successfully. |
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_82 ran successfully. |
Great to see coverage experience such an upward jump 👍 |
3cda10f to e0f08ad
// destination must be ample enough to accomodate all elements
{
    size_t range =
        static_cast<size_t>(dst_offsets.second - dst_offsets.first);
    if (range + 1 < static_cast<size_t>(nelems)) {
        throw py::value_error(
            "Memory addressed by the destination array can not "
            "accomodate all the "
            "array elements.");
    }
}
Sometime down the road it would be good to modularize this check into a utility as well, perhaps dpctl::tensor::utils::can_accomodate(nelems, dst). This is best done in a separate PR.
I was thinking about this as well during refactoring. I'll make note of it.
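The destination check quoted in this thread can be sketched in Python to show what it tests. This is only an illustration, not dpctl code: `offset_bounds` and `can_accommodate` are hypothetical helpers, strides are assumed to be expressed in elements, and the pair returned by `offset_bounds` plays the role of `dst_offsets` in the C++ snippet.

```python
def offset_bounds(shape, strides, offset=0):
    """Return (min, max) element offsets reachable by a strided array.

    Hypothetical helper: strides are in elements and may be negative.
    """
    if any(n == 0 for n in shape):
        # An empty array addresses no elements beyond its starting offset.
        return (offset, offset)
    lo = hi = offset
    for n, s in zip(shape, strides):
        span = (n - 1) * s
        if span < 0:
            lo += span
        else:
            hi += span
    return (lo, hi)


def can_accommodate(nelems, dst_offsets):
    """Mirror of the quoted check: the addressed range must cover nelems."""
    lo, hi = dst_offsets
    return (hi - lo) + 1 >= nelems


# A C-contiguous 3x4 destination spans offsets 0..11 and can hold 12 elements.
contiguous = offset_bounds((3, 4), (4, 1))
# A broadcast-like view with a zero stride spans only offsets 0..3.
broadcast_view = offset_bounds((3, 4), (0, 1))
print(can_accommodate(12, contiguous), can_accommodate(12, broadcast_view))
```

The zero-stride case is the kind of destination the check rejects: it has 12 logical elements but only 4 distinct memory locations.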
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_83 ran successfully. |
Looks awesome! Thank you @ndgrigorian
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_84 ran successfully. |
5c3c5d1 to a01db12
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_86 ran successfully. |
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_89 ran successfully. |
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_90 ran successfully. |
2e0bb63 to b4967a4
- Utility functions are for finding an output type for universal and binary functions when the device of allocation lacks fp16 or fp64
- Where now outputs an F-contiguous array when all inputs are F-contiguous
- Where now outputs an empty 0D array if any input is a 0D empty array
- Added tests for these cases
- Fixed incorrect logic in where test
b4967a4 to d50d1c6
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_93 ran successfully. |
- Asymmetric dtype test to improve coverage
Thank you, @ndgrigorian !
Please wait till github-action CI has finished and merge. This way artifacts would be ready to get published on |
Array API standard conformance tests for dpctl=0.14.3dev0=py310h76be34b_87 ran successfully. |
Deleted rendered PR docs from intelpython.github.com/dpctl, latest should be updated shortly. 🤞 |
Closes #1120
This PR implements the dpt.where function, which takes a condition array and two data arrays, x1 and x2. The output is an array whose elements come from x1 where condition evaluates to true and from x2 where it evaluates to false. The output data type is determined by type promotion of x1 and x2.
i.e.,
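The selection rule described above can be sketched in plain Python. This is an illustration under stated assumptions, not the dpctl implementation: `where_reference` is a hypothetical name, inputs are equal-length lists (no broadcasting), and type promotion is approximated by converting to `float` whenever either data input contains a float.

```python
def where_reference(condition, x1, x2):
    """Pick x1[i] where condition[i] is true, otherwise x2[i].

    Hypothetical reference sketch: plain lists of equal length only.
    """
    if not (len(condition) == len(x1) == len(x2)):
        raise ValueError("this sketch requires inputs of equal length")
    # Stand-in for type promotion of x1 and x2: the result is float
    # if either data input holds a float, and int otherwise.
    use_float = any(isinstance(v, float) for v in list(x1) + list(x2))
    cast = float if use_float else int
    return [cast(a if c else b) for c, a, b in zip(condition, x1, x2)]


result = where_reference([True, False, True], [1, 2, 3], [0.5, 0.5, 0.5])
print(result)  # [1.0, 0.5, 3.0]
```

Note that the elements taken from the integer input are converted as well, so the whole output shares one type, which is the point of promoting x1 and x2 before selecting.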