Skip to content

dpctl.tensor.where output preserves memory order of inputs #1342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 18, 2023

Conversation

ndgrigorian
Copy link
Collaborator

This PR adjusts the behavior of dpctl.tensor.where to preserve the memory layout of its inputs. This improves the access pattern of the kernel.

Performance of old behavior:

In [1]: import dpctl.tensor as dpt, numpy as np

In [2]: dt = dpt.int32

In [3]: sh = (8160, 8160)

In [4]: ar1 = dpt.ones(sh, dtype=dt, order="C")[:4080, ::-2].mT

In [5]: ar2 = dpt.zeros(sh, dtype=dt, order="C")[:4080, ::-2].mT

In [6]: condition = dpt.zeros(sh, dtype=dt, order="C")[:4080, ::-2].mT

In [7]: res = dpt.where(condition, ar1, ar2)
CPU times: user 143 ms, sys: 53.4 ms, total: 196 ms
Wall time: 220 ms

In [8]: %time res = dpt.where(condition, ar1, ar2)
CPU times: user 48.3 ms, sys: 67.6 ms, total: 116 ms
Wall time: 119 ms

In [9]: %time res = dpt.where(condition, ar1, ar2)
CPU times: user 81.5 ms, sys: 31.1 ms, total: 113 ms
Wall time: 114 ms

New behavior:

In [9]: %time res = dpt.where(condition, ar1, ar2)
CPU times: user 26.9 ms, sys: 198 µs, total: 27.1 ms
Wall time: 27.8 ms

In [10]: %time res = dpt.where(condition, ar1, ar2)
CPU times: user 16.1 ms, sys: 13.8 ms, total: 29.9 ms
Wall time: 30.9 ms

In [11]: %time res = dpt.where(condition, ar1, ar2)
CPU times: user 23.3 ms, sys: 1.43 ms, total: 24.7 ms
Wall time: 26.7 ms

which is an improvement of ~4x in certain cases.

_empty_like_triple_orderK is introduced for this purpose and a test was added to ensure that the output strides are as expected.

  • Have you provided a meaningful PR description?
  • Have you added a test, reproducer or referred to an issue with a reproducer?
  • Have you tested your changes locally for CPU and GPU devices?
  • Have you made sure that new changes do not introduce compiler warnings?
  • Have you checked performance impact of proposed changes?
  • If this PR is a work in progress, are you opening the PR as a draft?

@github-actions
Copy link

@coveralls
Copy link
Collaborator

coveralls commented Aug 14, 2023

Coverage Status

coverage: 85.072% (+0.03%) from 85.041% when pulling 873b0b6 on where-order-K into bd996b5 on master.

@github-actions
Copy link

Array API standard conformance tests for dpctl=0.14.6dev2=py310h7bf5fec_5 ran successfully.
Passed: 913
Failed: 87
Skipped: 119

- Now when operands are cast, stride simplification can still be performed on non-C contiguous inputs
- Implements _empty_like_triple_orderK to allocate output of where
- Now calls _empty_like_pair_orderK when two arrays are of equal shape and larger than the third
@github-actions
Copy link

Array API standard conformance tests for dpctl=0.14.6dev3=py310ha25a700_22 ran successfully.
Passed: 913
Failed: 87
Skipped: 119

- Dimensions of size 1 are effectively disregarded in sorting
@github-actions
Copy link

Array API standard conformance tests for dpctl=0.14.6dev3=py310ha25a700_23 ran successfully.
Passed: 913
Failed: 87
Skipped: 119

Copy link
Contributor

@oleksandr-pavlyk oleksandr-pavlyk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks good to go in @ndgrigorian. I had a small nitpick, but it just a matter of style.

@github-actions
Copy link

Array API standard conformance tests for dpctl=0.14.6dev3=py310ha25a700_24 ran successfully.
Passed: 913
Failed: 87
Skipped: 119

@ndgrigorian ndgrigorian merged commit 852f4b1 into master Aug 18, 2023
@github-actions
Copy link

Deleted rendered PR docs from intelpython.github.com/dpctl, latest should be updated shortly. 🤞

@ndgrigorian ndgrigorian deleted the where-order-K branch August 18, 2023 20:47
@github-actions
Copy link

Array API standard conformance tests for dpctl=0.14.6dev3=py310ha25a700_43 ran successfully.
Passed: 913
Failed: 87
Skipped: 119

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants