
Add a way to do power of 2 scaling #2256


Open · wants to merge 1 commit into base drisspg/stack/56

Conversation

@drisspg (Contributor) commented May 23, 2025


pytorch-bot bot commented May 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2256

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a07c9e2 with merge base a776b1f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

drisspg added a commit that referenced this pull request May 23, 2025
stack-info: PR: #2256, branch: drisspg/stack/57
@drisspg force-pushed the drisspg/stack/57 branch from 263096d to 6f24ba4, May 23, 2025 21:37
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 23, 2025
assert scale_dtype is torch.float8_e8m0fnu, "Only float8_e8m0fnu is supported"
scale = torch.exp2(torch.round(torch.log2(scale)))

return scale.to(dtype=torch.float32)
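The rounding trick above is easy to verify in isolation: taking log2, rounding to the nearest integer, and re-exponentiating snaps every positive scale to the nearest power of two. A minimal pure-Python sketch of the same logic (torch-free, for illustration only):

```python
import math

def round_to_power_of_2(scale: float) -> float:
    """Snap a positive scale to the nearest power of two.

    Mirrors torch.exp2(torch.round(torch.log2(scale))) elementwise.
    """
    assert scale > 0, "log2 requires a positive scale"
    # Python's round() rounds half to even, matching torch.round's behavior
    return 2.0 ** round(math.log2(scale))
```

For example, 3.0 rounds up to 4.0 and 0.3 rounds down to 0.25, while exact powers of two pass through unchanged.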
drisspg (Contributor, Author):
IMO this is a really great way to express this; the switch back is the only spooky part.

@danielvegamyhre (Contributor) commented May 23, 2025:

Hmm, so the API for using power-of-2 scales at inference would be to pass float8_e8m0 as the scale dtype, which is all exponent bits and can therefore represent only powers of 2, is that right? This is clever, but it does require a level of indirection that may confuse users. IMO it would be better to keep the API consistent with training, where it is just a config option, round_scales_to_powers_of_2.
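The suggestion above could look something like the following sketch; the config class and helper names here are hypothetical illustrations of a `round_scales_to_powers_of_2`-style option, not the actual torchao API:

```python
import math
from dataclasses import dataclass

@dataclass
class Float8InferenceConfig:  # hypothetical config, for illustration only
    # Mirrors the training-side option the reviewer mentions
    round_scales_to_powers_of_2: bool = False

def apply_scale_rounding(scale: float, config: Float8InferenceConfig) -> float:
    """Optionally snap a scale to the nearest power of two, driven by config."""
    if config.round_scales_to_powers_of_2:
        return 2.0 ** round(math.log2(scale))
    return scale
```

With this shape, callers opt in explicitly rather than having the behavior inferred from a float8_e8m0 scale dtype.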

@drisspg changed the base branch from drisspg/stack/56 to main, force-pushed the drisspg/stack/57 branch from 6f24ba4 to ecdcff8, then changed the base back to drisspg/stack/56 (May 23, 2025 21:38).
@danielvegamyhre (Contributor):
Referencing #2182 to link to issue

@drisspg continued restacking through May 23–24, 2025: the base branch was toggled between drisspg/stack/56 and main around each update while the drisspg/stack/57 branch was force-pushed through ecdcff8 → ee96b19 → 0524b89 → b286999 → c170f0c → a07c9e2. The float8 and topic: new feature labels were added May 23, 2025.
Labels
CLA Signed (managed by the Facebook bot; authors must sign the CLA before a PR can be reviewed) · float8 · topic: new feature (use this tag if this PR adds a new feature)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[float8] Support power of 2 scales with PerRow scales for inference
3 participants