primitive scale fix #2210
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2210. Note: links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
✅ No Failures as of commit ed81130 with merge base 554cb60.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D74446877
@pytorchbot label "topic: not user facing"
Summary: Pull Request resolved: pytorch#2210. Differential Revision: D74446877
Force-pushed 3fc617b to 0db5fc2
Force-pushed 0db5fc2 to cdc082f
Force-pushed cdc082f to fa63a56
Force-pushed fa63a56 to ed81130
@@ -948,7 +952,9 @@ def _choose_qparams_affine(
         scale = torch.clamp(scale, min=eps)
     else:
         assert mapping_type == MappingType.ASYMMETRIC.name
-        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
+        scale = (max_val_pos - min_val_neg) / torch.tensor(
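The added line is truncated in the diff view above. Going by the review discussion below, the divisor is presumably wrapped in a tensor with an explicit dtype and device; here is a minimal sketch of the before/after pattern with made-up values (the exact arguments in the real code are an assumption):

import torch

# Hypothetical stand-ins for the values inside _choose_qparams_affine.
max_val_pos = torch.tensor([2.0], dtype=torch.bfloat16)
min_val_neg = torch.tensor([-1.0], dtype=torch.bfloat16)
quant_min, quant_max = -128, 127  # Python ints

# Before: divide by a Python float, i.e. a 64-bit scalar.
scale_old = (max_val_pos - min_val_neg) / float(quant_max - quant_min)

# After (assumed completion of the truncated "+" line): divide by a tensor
# that explicitly matches the numerator's dtype and device.
scale_new = (max_val_pos - min_val_neg) / torch.tensor(
    float(quant_max - quant_min),
    dtype=max_val_pos.dtype,
    device=max_val_pos.device,
)
print(scale_old, scale_new)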
what if you did:
(max_val_pos - min_val_neg) / (quant_max - quant_min).to(torch.float32)
Casting to float32 doesn't help the discrepancy on CPU vs GPU.
Ahh sorry, I see that quant_max and quant_min are ints. Why is this better than the existing code? Maybe a test might be helpful:
import torch
from transformer_nuggets.utils.tracing import LoggingMode

a = torch.randn(2, 3, device="cuda")

# Divide by a Python float (scalar divisor).
with LoggingMode():
    out = a / float(32 - 12)
print(out)

# Divide by a 0-dim tensor matching a's dtype and device (tensor divisor).
with LoggingMode():
    out = a / torch.tensor(float(32 - 12), dtype=a.dtype, device=a.device)
print(out)
Produces:
$1: f32[2, 3] = aten.div.Tensor($0, 20.0)
tensor([[-0.0313, -0.0028, -0.0344],
[ 0.0299, 0.0099, -0.0426]], device='cuda:0')
$0: f32[] = aten.lift_fresh.default($0)
$2: f32[2, 3] = aten.div.Tensor($1, $0)
tensor([[-0.0313, -0.0028, -0.0344],
[ 0.0299, 0.0099, -0.0426]], device='cuda:0')
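The trace above is at float32 on CUDA, where both forms print the same values. Below is a small follow-up sketch, not from the PR, that compares the scalar-divisor and tensor-divisor paths bitwise on CPU and, if available, GPU; the helper names are made up, and whether the outputs actually diverge depends on dtype, values, and PyTorch version.

import torch

def scale_scalar_div(x, qmax, qmin):
    # Old path: divide by a Python float (64-bit) scalar.
    return x / float(qmax - qmin)

def scale_tensor_div(x, qmax, qmin):
    # New path: divisor materialized as a tensor with x's dtype and device.
    return x / torch.tensor(float(qmax - qmin), dtype=x.dtype, device=x.device)

x_cpu = torch.randn(1024, dtype=torch.float32)
print(torch.equal(scale_scalar_div(x_cpu, 127, -128),
                  scale_tensor_div(x_cpu, 127, -128)))

if torch.cuda.is_available():
    x_gpu = x_cpu.cuda()
    # Do CPU and GPU agree bit-for-bit for each path?
    print(torch.equal(scale_scalar_div(x_cpu, 127, -128),
                      scale_scalar_div(x_gpu, 127, -128).cpu()))
    print(torch.equal(scale_tensor_div(x_cpu, 127, -128),
                      scale_tensor_div(x_gpu, 127, -128).cpu()))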
Driss, the issue is that "float", as in a Python float, is actually float64.
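A quick illustration of that point, using only standard Python/PyTorch behavior rather than anything from the PR: a Python float carries float64 precision, and an actual float64 tensor would even change the result dtype through type promotion, which is presumably why the new divisor tensor pins dtype (and device) explicitly.

import sys
import torch

# A Python float is a C double: 53 mantissa bits, i.e. float64 precision.
print(sys.float_info.mant_dig)  # 53

# 2**24 + 1 is exactly representable as a Python float but not as float32.
d = float(2**24 + 1)
print(d, torch.tensor(d).item())  # 16777217.0 vs 16777216.0

a = torch.randn(4)  # float32
print((a / d).dtype)  # torch.float32: a Python scalar does not upcast the result
print((a / torch.tensor(d, dtype=torch.float64)).dtype)  # torch.float64: a float64 tensor does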
So this isn't about fixing a device but about fixing the dtype?
OK, I think it's fine to merge as long as it doesn't break existing tests; there is some flexibility in these quant primitive ops since we didn't really define them very precisely.
Differential Revision: D74446877