Add exhaustive testing to ValueRanges, fix bugs #94939

Closed

ezyang
Contributor

@ezyang ezyang commented Feb 15, 2023

Stack from ghstack (oldest at bottom):

Since I didn't want to deal with nondeterministic tests, I went the exhaustive testing route for a fixed list of constants to look at. The tests generate random ranges, propagate the range through the function, and then pick elements in the range and check that the result of the operation is in the resulting range. This caught bugs in log, sqrt and pow.
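
A minimal sketch of this style of check, not the test added by this PR: the constant list, the `sqrt_range` rule, and the helper names are all assumptions made for illustration. It builds every range from a fixed list of constants, propagates the range through a toy rule, and asserts that the concrete result for each in-range constant lands inside the propagated range.

```python
import itertools
import math

# Hypothetical fixed list of constants; the real test's list may differ.
CONSTANTS = [-2.0, -1.0, 0.0, 0.5, 1.0, 2.0]


def ranges(consts):
    """Yield every (lower, upper) pair with lower <= upper."""
    for lo, hi in itertools.product(consts, repeat=2):
        if lo <= hi:
            yield lo, hi


def sqrt_range(lo, hi):
    """Toy range rule for sqrt: only defined when the whole range is >= 0."""
    if lo < 0:
        return None  # give up: part of the range is outside sqrt's domain
    return math.sqrt(lo), math.sqrt(hi)


def check_sqrt():
    """Check soundness of sqrt_range; return the number of points checked."""
    checked = 0
    for lo, hi in ranges(CONSTANTS):
        out = sqrt_range(lo, hi)
        if out is None:
            continue
        for x in CONSTANTS:
            if lo <= x <= hi:
                y = math.sqrt(x)
                # The concrete result must fall inside the propagated range.
                assert out[0] <= y <= out[1], (lo, hi, x)
                checked += 1
    return checked
```

Because both the ranges and the sample points come from the same fixed list, the check is deterministic and covers every combination.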

My resolution for pow was a little special, because I had trouble figuring out the correct semantics under all input domains. Instead, I picked two input domains (pow on two point ranges, and pow where the exponent is known) and only implemented those. For everything else, we give up. I think this is unlikely to affect perf.
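
As a rough illustration of handling only two domains and giving up otherwise, here is a sketch that is not the code in `value_ranges.py`. The `Range` class, the `UNKNOWN` sentinel, and the restriction of the known-exponent case to non-negative integers are assumptions made to keep the example small.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Range:
    """Hypothetical closed interval [lower, upper]."""
    lower: float
    upper: float

    def is_point(self):
        return self.lower == self.upper


# "Give up": the result could be anything.
UNKNOWN = Range(-math.inf, math.inf)


def pow_range(a, b):
    # Domain 1: both operands are single points, so just evaluate.
    if a.is_point() and b.is_point():
        try:
            v = float(a.lower) ** b.lower
        except (ZeroDivisionError, OverflowError):
            return UNKNOWN
        if isinstance(v, complex):  # negative base, fractional exponent
            return UNKNOWN
        return Range(v, v)
    # Domain 2: the exponent is a known non-negative integer.
    if b.is_point() and b.lower >= 0 and float(b.lower).is_integer():
        n = int(b.lower)
        ends = [a.lower ** n, a.upper ** n]
        if n > 0 and n % 2 == 0 and a.lower <= 0 <= a.upper:
            # Even powers reach their minimum, 0, inside the range.
            return Range(0.0, max(ends))
        return Range(min(ends), max(ends))
    # Everything else: give up.
    return UNKNOWN
```

For example, `pow_range(Range(-2.0, 3.0), Range(2.0, 2.0))` gives `Range(0.0, 9.0)`, while a non-point exponent such as `Range(0.0, 1.0)` falls through to `UNKNOWN`.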

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

cc @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
@pytorch-bot
pytorch-bot bot commented Feb 15, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94939

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit bf8f3d7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request Feb 15, 2023
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 6a3d0236bcf4ad1e8b8070a519c2f802d9105c62
Pull Request resolved: #94939
@albanD albanD removed their request for review February 15, 2023 21:48
@ezyang ezyang changed the title Add fuzz testing to ValueRanges, fix bugs Add exhaustive testing to ValueRanges, fix bugs Feb 15, 2023
Comment on lines +291 to +292
# This is fairly difficult to analyze, so give up for anything
# complicated
Collaborator

Right, I see we were missing a couple cases here, yep.

Collaborator

For future reference, it's probably enough to analyse the signs of a and b separately and have those 4 cases, but yeah, probably not worth the effort.

Contributor Author

PRs welcome! The problem I had was that you also have different regimes for a < 1 vs > 1 vs = 1. PAIN
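
To make the three regimes concrete, here is a small sketch with hypothetical helper names that covers only a fixed positive base. For such a base, `a ** b` is monotonic in `b`, but the direction flips at `a == 1`, so which endpoint gives the minimum depends on the base.

```python
def exponent_regime(a):
    """Describe how a ** b changes as b grows, for a fixed positive base a."""
    if a <= 0:
        raise ValueError("this sketch only covers positive bases")
    if a < 1:
        return "decreasing"
    if a == 1:
        return "constant"
    return "increasing"


def pow_range_fixed_base(a, b_lo, b_hi):
    """Range of a ** b for b in [b_lo, b_hi], with a > 0 and finite bounds."""
    ends = (a ** b_lo, a ** b_hi)
    # a ** b is monotonic in b for a fixed a > 0, so the endpoints suffice;
    # taking min/max avoids having to branch on the regime.
    return min(ends), max(ends)
```

Once the base is itself a range that straddles 1 or 0, no single regime applies across the whole range, which is presumably where the case analysis becomes painful.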

Collaborator

I'll submit a PR once this one's merged.

torch/utils/_sympy/value_ranges.py (outdated, resolved)
Contributor

@eellison eellison left a comment

nice !

@@ -119,15 +122,45 @@ def ceil(x):
    return math.ceil(x)


def valid_unary(fn, v):
Contributor

we could also run the reference first and see if it throws, then skip

Contributor Author

Still good to avoid trying to eval something silly like 2 ** 10000 which sympy will do symbolically lol
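
A sketch of what such an up-front guard might look like. The function names, the operation names, and the `max_exponent` cutoff are assumptions for illustration, not the helpers in `test/test_value_ranges.py`. The idea is to reject inputs before evaluating the reference at all, both where the operation is undefined and where an exact evaluation would be needlessly expensive.

```python
def is_valid_unary(fn, v):
    """Return False for inputs outside the domain of the named operation."""
    if fn == "log" and v <= 0:
        return False
    if fn == "sqrt" and v < 0:
        return False
    if fn == "reciprocal" and v == 0:
        return False
    return True


def is_valid_binary(fn, a, b, max_exponent=4):
    """Return False for undefined or pointlessly expensive input pairs."""
    if fn == "pow":
        if abs(b) > max_exponent:
            return False  # skip huge exact powers such as 2 ** 10000
        if a == 0 and b < 0:
            return False  # division by zero
        if a < 0 and not float(b).is_integer():
            return False  # result would be complex
    if fn in ("truediv", "mod") and b == 0:
        return False
    return True
```

Running the reference and skipping on an exception, as suggested above, would handle the domain errors but not the expensive-but-valid cases, which is why an explicit predicate is still useful.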

test/test_value_ranges.py (resolved)
torch/utils/_sympy/value_ranges.py (resolved)
@ezyang
Contributor Author

ezyang commented Feb 16, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 16, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot
Collaborator

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

@ezyang
Contributor Author

ezyang commented Feb 16, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@github-actions github-actions bot added the module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration label Feb 16, 2023
@pytorchmergebot
Collaborator

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Collaborator

@nunoplopes nunoplopes left a comment

LGTM!

@yanbing-j
Collaborator

Hi @ezyang, could you please double-check the change to third_party/ideep? I believe that change is unnecessary. Could you please reland this PR? Thanks!

@ezyang
Contributor Author

ezyang commented Feb 28, 2023

It's not intentional. You won't be able to revert this PR, please submit a forward fix PR.

@yanbing-j
Collaborator

Please review the forward fix PR #95688. Thanks!

pytorchmergebot pushed a commit that referenced this pull request Feb 28, 2023
### Description
This PR updates ideep to add a primitive cache in order to speed up PyTorch workloads on ARM.
Reland #94719, which was unintentionally reverted by #94939 (comment).
Fixes #94264.

### Performance test
Use TorchBench test in ICX with 40 cores
Intel OpenMP & jemalloc were preloaded
![image](https://user-images.githubusercontent.com/61222868/221760391-fb6cbabe-6d88-4155-b216-348e718e68b9.png)

Pull Request resolved: #95688
Approved by: https://github.com/ezyang
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 2, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
pruthvistony added a commit to ROCm/pytorch that referenced this pull request May 2, 2023
@facebook-github-bot facebook-github-bot deleted the gh/ezyang/1819/head branch June 8, 2023 16:50
jhavukainen pushed a commit to kulinseth/pytorch that referenced this pull request Mar 15, 2024
Pull Request resolved: pytorch#94939
Approved by: https://github.com/lezcano, https://github.com/eellison, https://github.com/nunoplopes
Labels: ciflow/trunk, Merged, module: mkldnn, release notes: inductor, topic: bug fixes
6 participants