Fix Per Row scaling for inference #2253
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2253
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit b10c3cc with merge base a776b1f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
stack-info: PR: #2253, branch: drisspg/stack/56
Pull Request Overview
This PR restores per-row scaling support for float8 quantization and updates related APIs and tests.
- Add `block_size` support to quant parameter selection and dequantization for block-wise scaling (see the usage sketch below)
- Normalize granularity in the quantization API post-init
- Fix row-wise scale handling in float8 layouts and update tests to verify per-row scale shapes
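As a rough illustration of the user-facing behavior this restores — a minimal sketch assuming torchao's config-based quantization API; the exact import paths and config name may differ between versions:

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

# Minimal sketch: per-row float8 quantization of a linear layer.
# Config/class names are assumed from torchao's quant_api and may vary.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
)

# A single granularity is normalized in __post_init__ into an
# (activation, weight) pair, per this PR.
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

# After quantization, the weight's float8 scale should have one element per
# output row (e.g. shape (256, 1)), not a single global scale.
```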
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| quant_primitives.py | Introduce `block_size` argument for block-wise scale compute and update dequant logic |
| quant_api.py | Normalize granularity tuple in `__post_init__` |
| float8_layout.py | Adjust row-wise scale transpose logic |
| affine_quantized_tensor.py | Pass `block_size` into `choose_qparams_affine_float8` |
| test_affine_quantized_float.py | Parametrize and verify per-row scale shapes in tests |
Comments suppressed due to low confidence (4)
torchao/quantization/quant_primitives.py:1985
- The comment about supporting only tensorwise scaling is outdated since block-wise scaling is now implemented; please update or remove this comment to avoid confusion.
# only tensorwise scaling is supported for now:
torchao/quantization/quant_api.py:1579
- [nitpick] This normalization block appears duplicated from the class definition in the context excerpt; consider consolidating to avoid redundant logic.
activation_granularity, weight_granularity = _normalize_granularity(
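For context, a hypothetical sketch of what such a normalization helper might look like (the real `_normalize_granularity` in `quant_api.py` may differ in signature and error handling):

```python
from typing import Optional, Tuple, Union

from torchao.quantization.granularity import Granularity, PerTensor

# Hypothetical sketch only; the real helper may differ.
def _normalize_granularity(
    granularity: Optional[Union[Granularity, Tuple[Granularity, Granularity]]],
) -> Tuple[Granularity, Granularity]:
    if granularity is None:
        # Default both activations and weights to per-tensor scaling.
        return (PerTensor(), PerTensor())
    if isinstance(granularity, Granularity):
        # A single granularity applies to both activations and weights.
        return (granularity, granularity)
    # Otherwise it is already an (activation, weight) pair.
    assert len(granularity) == 2, "expected a 2-tuple of granularities"
    return granularity
```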
torchao/dtypes/floatx/float8_layout.py:341
- Removing the `unsqueeze(-1)` before transpose changes the tensor's rank and may break row-wise scale alignment; restore the unsqueeze or adjust downstream logic to match expected dimensions.
w_scale = w_scale.T
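To illustrate the rank concern (with assumed shapes; the actual scale layout in `float8_layout.py` may differ):

```python
import torch

# For a weight of shape (out_features, in_features), a row-wise scale has
# one element per output row.
w_scale = torch.randn(256)      # rank-1: one scale per row

# Previous path: unsqueeze then transpose yields a rank-2 (1, 256) tensor.
old = w_scale.unsqueeze(-1).T   # shape (1, 256)

# Path in this PR: .T on a rank-1 tensor is a no-op, so the result stays
# rank-1 with shape (256,), and downstream broadcasting must account for it.
new = w_scale.T                 # shape (256,)
```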
torchao/dtypes/affine_quantized_tensor.py:465
- The variable `block_size` is used here but not defined or passed into this scope; ensure `block_size` is available or passed from the calling context.
scale = choose_qparams_affine_float8(
force-pushed from 5d02444 to 4d7c98f
force-pushed from e6800d3 to f106d46
```python
) -> torch.Tensor:
    """
    Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.

    Args:
        tensor (torch.Tensor): Input tensor to be quantized.
        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
        scale_dtype (torch.dtype): Data type of the scaling factor (e.g., torch.float32).
        block_size (Optional[Tuple[int, ...]]): Block size for block-wise quantization. If None, tensorwise quantization is used.
```
was support for blockwise quantization intentionally included in this PR? this looks like the quant primitive support for it, but presumably other things are still needed, like scaling granularity (which is only PerTensor or PerRow right now), etc.
This primitive can/should be generic over block_size, which is what I did. Arguably I should update the docstring a little, but it's up to the caller to pass in the right block_size, which today takes only two different forms: per-tensor and per-row.
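Concretely, under torchao's convention that `block_size` is the shape of the region sharing one scale, the two forms look like this (a sketch; the keyword names follow the docstring above and may vary):

```python
import torch
from torchao.quantization.quant_primitives import choose_qparams_affine_float8

x = torch.randn(256, 128, dtype=torch.bfloat16)

# Per-tensor: the block covers the whole tensor, so there is one scale.
per_tensor_block_size = tuple(x.shape)     # (256, 128)

# Per-row: one block per row, so there is one scale per row.
per_row_block_size = (1, x.shape[-1])      # (1, 128)

scale = choose_qparams_affine_float8(
    x,
    float8_dtype=torch.float8_e4m3fn,
    block_size=per_row_block_size,
)
# Expect one scale element per row, e.g. shape (256, 1).
```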
Stacked PRs:
Summary
Somewhere in the myriad of refactors we broke per-row scaling. This went unnoticed because dequant just so happened to work with a global scale, so we now test for it.
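A hedged sketch of the kind of regression test this adds in `test_affine_quantized_float.py` (the attribute path to the scale and the expected shape convention are assumptions; the real test may be structured differently):

```python
import pytest
import torch

from torchao.quantization import quantize_
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig


@pytest.mark.parametrize("granularity", [PerTensor(), PerRow()])
def test_float8_scale_shape(granularity):
    m = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=granularity))

    # Assumed attribute path to the quantized weight's scale.
    scale = m.weight.tensor_impl.scale
    if isinstance(granularity, PerRow):
        # One scale per output row, not a single global scale.
        assert scale.numel() == m.out_features
    else:
        assert scale.numel() == 1
```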