Skip to content

Fix Per Row scaling for inference #2253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Fix Per Row scaling for inference #2253

wants to merge 1 commit into from

Conversation

drisspg
Copy link
Contributor

@drisspg drisspg commented May 23, 2025

Stacked PRs:


Summary

In somewhere in the myriad of refactors we broke per-row scaling. This had no effect becuase dequant just so happend to work and a global scale, so we test now

Copy link

pytorch-bot bot commented May 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2253

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b10c3cc with merge base a776b1f (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

drisspg added a commit that referenced this pull request May 23, 2025
stack-info: PR: #2253, branch: drisspg/stack/56
@drisspg drisspg force-pushed the drisspg/stack/56 branch from 8e43100 to ae92547 Compare May 23, 2025 18:32
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 23, 2025
@drisspg drisspg added bug Something isn't working high priority topic: bug fix Use this tag for PRs that fix bugs labels May 23, 2025
@drisspg drisspg marked this pull request as draft May 23, 2025 18:44
drisspg added a commit that referenced this pull request May 23, 2025
stack-info: PR: #2253, branch: drisspg/stack/56
@drisspg drisspg force-pushed the drisspg/stack/56 branch from ae92547 to bd0df40 Compare May 23, 2025 20:16
drisspg added a commit that referenced this pull request May 23, 2025
stack-info: PR: #2253, branch: drisspg/stack/56
@drisspg drisspg force-pushed the drisspg/stack/56 branch from bd0df40 to a3ec2a2 Compare May 23, 2025 21:02
@drisspg drisspg marked this pull request as ready for review May 23, 2025 21:02
drisspg added a commit that referenced this pull request May 23, 2025
stack-info: PR: #2253, branch: drisspg/stack/56
@drisspg drisspg force-pushed the drisspg/stack/56 branch from a3ec2a2 to a2f2f09 Compare May 23, 2025 21:11
drisspg added a commit that referenced this pull request May 23, 2025
stack-info: PR: #2253, branch: drisspg/stack/56
@drisspg drisspg force-pushed the drisspg/stack/56 branch from a2f2f09 to 83640b7 Compare May 23, 2025 21:14
drisspg added a commit that referenced this pull request May 23, 2025
stack-info: PR: #2253, branch: drisspg/stack/56
@drisspg drisspg force-pushed the drisspg/stack/56 branch from 83640b7 to f550778 Compare May 23, 2025 21:19
@drisspg drisspg force-pushed the drisspg/stack/56 branch from f550778 to 5fa37cb Compare May 23, 2025 21:38
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR restores per-row scaling support for float8 quantization and updates related APIs and tests.

  • Add block_size support to quant parameter selection and dequantization for block-wise scaling
  • Normalize granularity in the quantization API post-init
  • Fix row-wise scale handling in float8 layouts and update tests for verifying per-row scale shapes

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

Show a summary per file
File Description
quant_primitives.py Introduce block_size argument for block-wise scale compute and update dequant logic
quant_api.py Normalize granularity tuple in __post_init__
float8_layout.py Adjust row-wise scale transpose logic
affine_quantized_tensor.py Pass block_size into choose_qparams_affine_float8
test_affine_quantized_float.py Parametrize and verify per-row scale shapes in tests
Comments suppressed due to low confidence (4)

torchao/quantization/quant_primitives.py:1985

  • The comment about supporting only tensorwise scaling is outdated since block-wise scaling is now implemented; please update or remove this comment to avoid confusion.
# only tensorwise scaling is supported for now:

torchao/quantization/quant_api.py:1579

  • [nitpick] This normalization block appears duplicated from the class definition in the context excerpt; consider consolidating to avoid redundant logic.
activation_granularity, weight_granularity = _normalize_granularity(

torchao/dtypes/floatx/float8_layout.py:341

  • Removing the unsqueeze(-1) before transpose changes the tensor's rank and may break row-wise scale alignment; restore the unsqueeze or adjust downstream logic to match expected dimensions.
w_scale = w_scale.T

torchao/dtypes/affine_quantized_tensor.py:465

  • The variable block_size is used here but not defined or passed into this scope; ensure block_size is available or passed from the calling context.
scale = choose_qparams_affine_float8(

@drisspg drisspg force-pushed the drisspg/stack/56 branch 2 times, most recently from 5d02444 to 4d7c98f Compare May 23, 2025 22:38
@drisspg drisspg force-pushed the drisspg/stack/56 branch 2 times, most recently from e6800d3 to f106d46 Compare May 23, 2025 23:24
@drisspg drisspg requested a review from danielvegamyhre May 23, 2025 23:28
stack-info: PR: #2253, branch: drisspg/stack/56
) -> torch.Tensor:
"""
Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
Args:
tensor (torch.Tensor): Input tensor to be quantized.
float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
scale_dtype (torch.dtype): Data type of the scaling factor (e.g., torch.float32).
block_size (Optional[Tuple[int, ...]]): Block size for block-wise quantization. If None, tensorwise quantization is used.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was support for blockwise quantization intentionally included in this PR? this looks like the quant primitive support for it but presumably other things are still needed like scaling granularity (which only is PerTensor or PerRow right now), etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This primitive, can/should be generic to block-size - which is what I did arguablly I hsould update the doc string a little, but its up to the caller to pass in the right block_size which in this case is only 2 difffernt forms today per-tensor and per-row

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. high priority topic: bug fix Use this tag for PRs that fix bugs
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants