Support MX4 E3M0 format and add stochastic rounding #477
@@ -21,6 +21,8 @@

from torchao.prototype.mx_formats.constants import (
    DTYPE_FP4,
    DTYPE_FP4_E2M1,
for all the changes in this file, can you rebase past #363?
and in the code after that PR, perhaps we can add a stochastic rounding option to _f32_to_fpx_unpacked?
for rounding mode, how about something like this instead of a boolean?
import enum
from enum import auto

class RoundingMode(enum.Enum):
    TIE_TO_EVEN = auto()  # default
    STOCHASTIC = auto()  # added in this PR

def foo(..., rounding_mode=RoundingMode.TIE_TO_EVEN, ...): ...
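To illustrate why an enum reads better than a boolean here, a minimal sketch of dispatching on the suggested rounding mode follows. `round_to_grid` and its `step` and `rng` parameters are hypothetical names made up for this example; this is not torchao code and says nothing about how the PR implements the cast.

```python
import enum
import math
import random


class RoundingMode(enum.Enum):
    TIE_TO_EVEN = enum.auto()  # default
    STOCHASTIC = enum.auto()


def round_to_grid(x, step, rounding_mode=RoundingMode.TIE_TO_EVEN, rng=None):
    """Hypothetical helper: round x to a multiple of step."""
    scaled = x / step
    if rounding_mode is RoundingMode.TIE_TO_EVEN:
        # Python's built-in round() rounds halfway cases to the even integer.
        return round(scaled) * step
    if rounding_mode is RoundingMode.STOCHASTIC:
        rng = rng or random
        lower = math.floor(scaled)
        frac = scaled - lower
        # Round up with probability equal to the fractional part, so the
        # result equals x in expectation.
        return (lower + (1 if rng.random() < frac else 0)) * step
    raise ValueError(f"unsupported rounding mode: {rounding_mode}")
```

With an enum, a call site such as `round_to_grid(x, 0.5, rounding_mode=RoundingMode.STOCHASTIC)` states its intent directly, and adding a further mode later does not require changing the signature.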
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("sign", [1, -1])
@pytest.mark.parametrize("use_stochastic_rounding", [False, True])
def test_overflow_cast(hp_dtype, device, sign, use_stochastic_rounding):
can we add these tests to test/prototype/mx_formats/test_custom_cast.py to keep the testing of MX numerics in one place?
thanks for adding this! left some comments, mostly on rebasing past https://github.com/pytorch/ao/pull/363/files and code style
The E3M0 numerics implementation looks good to me.
denormal_x = denormal_x.view(torch.float)

# adjust the denormal values back
denormal_x -= min_normal
The SR code up to this line looks good to me.
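For readers unfamiliar with stochastic rounding (SR) on floating-point bit patterns, here is a minimal pure-Python sketch of one common approach: add uniform random bits below the kept mantissa bits, then truncate. It is an illustration under stated assumptions, not the code in this PR. The helper names are hypothetical, the input is assumed to be a finite, normal float32 value, and overflow and the denormal adjustment shown in the diff above are not handled.

```python
import random
import struct

MBITS_F32 = 23  # number of mantissa bits in float32


def f32_to_bits(x):
    """Reinterpret a float32 value as an unsigned 32-bit integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(b):
    """Reinterpret an unsigned 32-bit integer as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


def stochastic_round_mantissa(x, mbits, rng):
    """Hypothetical helper: keep `mbits` mantissa bits of float32 `x`.

    Assumes x is finite and normal; no overflow or denormal handling.
    """
    drop = MBITS_F32 - mbits
    if drop <= 0:
        return bits_to_f32(f32_to_bits(x))
    bits = f32_to_bits(x)
    # Random noise in [0, 2**drop): a carry into the kept bits happens with
    # probability proportional to the bits that are about to be dropped.
    noise = rng.getrandbits(drop)
    bits = (bits + noise) & 0xFFFFFFFF
    # Truncate the dropped bits.
    bits &= ~((1 << drop) - 1) & 0xFFFFFFFF
    return bits_to_f32(bits)
```

With `mbits=0`, as in a mantissa-less format such as E3M0, an input of 1.25 rounds to either 1.0 or 2.0, and the average over many draws approaches 1.25. A carry out of the mantissa simply increments the exponent field, which is why the same add-then-truncate step also covers rounding up to the next power of two.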
57165ff to 60fe4c3
60fe4c3 to e202e6e
55dfe1f to 58a9f01
58a9f01 to 45520d2
No description provided.