[float8] Allow specifying arbitrary dtype for each tensor #1326
base: gh/lw/2/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1326
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure. As of commit 97c9983 with merge base 1a0dbf1, the following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -62,6 +62,7 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    dtype: Optional[torch.dtype] = None
nit:

- can we add a comment on what this is used for, and that `None` means the default e4m3|e5m2 value will be used?
- optional - thoughts about naming this in a more specific way such as `target_dtype`, `lowp_dtype`, etc? `dtype` is a bit ambiguous across torchao unfortunately :(
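A minimal sketch of what the suggested field comment could look like (a trimmed-down stand-in for CastConfig, not the real class; the role-based default it describes is the e4m3-for-forward / e5m2-for-grad convention from the comment above):

```python
import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class CastConfig:
    # ... existing fields (scaling_type, scaling_granularity, static_scale) elided ...

    # Target float8 dtype for this tensor's cast. None means the role-based
    # default is used: e4m3 for input/weight, e5m2 for grad_output.
    dtype: Optional[torch.dtype] = None
```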
@@ -343,12 +367,14 @@ def recipe_name_to_linear_config(
     cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

     # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-    cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+    cc_go = CastConfig(
+        scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
nit: maybe we can also add some context in the comments on L353:L363 that it also uses e4m3 for grads?
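For illustration, a sketch of the kind of commented recipe code the reviewer is asking for; the import path and the `e4m3_dtype` value are assumptions, while the cast layout mirrors the diff above:

```python
import torch
# Assumed import path for the config types shown in the diff.
from torchao.float8.config import CastConfig, ScalingGranularity

# torchao's e4m3_dtype alias; written out here for self-containment.
e4m3_dtype = torch.float8_e4m3fn

cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
# Note: grad_output is cast to e4m3 here rather than the usual e5m2 default,
# which is the context worth spelling out in the comments on L353:L363.
cc_go = CastConfig(
    scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
)
```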
-    NoopFwToFloat8E5M2BwDelayed,
-    NoopFwToFloat8E5M2BwDynamic,
-    NoopFwToFloat8E5M2BwStatic,
+    NoopFwToFloat8BwDelayed,
thanks for updating these!
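For readers skimming the rename: the E5M2 suffix is dropped because the backward cast dtype is no longer hard-coded. A minimal sketch of the pattern (illustrative only, not torchao's implementation, which wraps results in a scaled Float8Tensor):

```python
import torch

class NoopFwCastBw(torch.autograd.Function):
    """No-op in forward; cast the gradient to a configurable float8 dtype in
    backward. The configurable target_dtype is what replaces the hard-coded
    E5M2 in the old class names."""

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, target_dtype: torch.dtype):
        ctx.target_dtype = target_dtype  # e.g. torch.float8_e5m2 or torch.float8_e4m3fn
        return tensor

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Round-trip through the target dtype to show where the cast happens;
        # the real code also applies a scale and keeps the float8 representation.
        grad_fp8 = grad_output.to(ctx.target_dtype)
        return grad_fp8.to(grad_output.dtype), None
```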
@@ -303,13 +311,16 @@ def inner_func():

     # Calculate the new scales from the updated history stacks
     new_input_scales = amax_history_to_scale_stack(
-        fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
+        fp8_input_amax_history_stack, input_dtype, x_dtype, scale_fn_recipe
will likely have to rebase on top of #1329 which changed this line
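For reference, a simplified per-tensor version of the scale computation this call feeds, with the target float8 dtype passed in instead of being fixed to e4m3 (an illustration of the idea, not torchao's exact stacked implementation):

```python
import torch

def amax_history_to_scale(
    amax_history: torch.Tensor,  # (history_len,) running amax values
    float8_dtype: torch.dtype,   # e.g. torch.float8_e4m3fn; now configurable per tensor
    orig_dtype: torch.dtype,     # dtype of the high-precision tensor
    recipe: str = "max",
) -> torch.Tensor:
    # Delayed scaling: derive the scale from the recent amax history.
    if recipe != "max":
        raise ValueError(f"unsupported recipe: {recipe}")
    amax = torch.clamp(torch.max(amax_history), min=torch.finfo(orig_dtype).tiny)
    # Map the observed amax onto the representable range of the target dtype.
    return torch.finfo(float8_dtype).max / amax
```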
@@ -62,6 +62,7 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    dtype: Optional[torch.dtype] = None

     def short_str(self):
can we also add the dtype here, so it appears when we print an instance of `Float8Linear`? `Float8Linear.__extra_repr__` calls this method.
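A possible shape for that change, purely as a sketch (the existing short_str body is not shown in the diff, so the base string here is made up; the point is just appending an abbreviated dtype when one is set explicitly):

```python
import torch

_DTYPE_ABBR = {
    torch.float8_e4m3fn: "e4m3",
    torch.float8_e5m2: "e5m2",
}

def short_str(self) -> str:
    # Hypothetical CastConfig.short_str: surfaces the configured dtype so it
    # shows up when Float8Linear.__extra_repr__ prints its cast configs.
    base = f"{self.scaling_type.name.lower()[:3]}_{self.scaling_granularity.name.lower()[:3]}"
    if self.dtype is None:
        return base  # role-based default (e4m3/e5m2); nothing extra to print
    return f"{base}_{_DTYPE_ABBR.get(self.dtype, str(self.dtype))}"
```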
This is great! LGTM, had some comments but all are pretty nitty. CI is green - ship it!
Superseded by #1378 |
Stack from ghstack (oldest at bottom):