[float8nocompile] Simplified Float8Linear implementation which only supports dynamic tensorwise scaling #1429
Conversation
CI: ✅ No failures as of commit 450f140 with merge base a5a53a2. Test artifacts and rendered results: hud.pytorch.org/pr/pytorch/ao/1429
convert_to_float8_nocompile_training(m)
print("finished convert_to_float8_nocompile_training")

for i in range(10):
This makes sense; I'd recommend adding a stronger numerical equivalency check (can be in a future PR):
- create an example input
- create (a) a reference model and (b) your model
- feed the example input through both and run backwards
- compare that the model outputs, weight gradients, and input gradients are equivalent across (a) and (b) (see the sketch after this list)
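A minimal sketch of what such a check could look like, assuming the converter entry point shown in the example script; the import path and tolerances below are illustrative, not part of this PR:

```python
import copy

import torch
import torch.nn as nn

# Illustrative import path; use wherever convert_to_float8_nocompile_training
# is defined in this PR.
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)


def check_equivalence(m_ref: nn.Module, example_input: torch.Tensor) -> None:
    # (a) reference model, (b) a converted deep copy of the same model.
    m_fp8 = copy.deepcopy(m_ref)
    convert_to_float8_nocompile_training(m_fp8)

    x_ref = example_input.detach().clone().requires_grad_(True)
    x_fp8 = example_input.detach().clone().requires_grad_(True)

    # Feed the same input through both models and run backwards.
    out_ref = m_ref(x_ref)
    out_fp8 = m_fp8(x_fp8)
    out_ref.sum().backward()
    out_fp8.sum().backward()

    # Tolerances must be loose because of float8 quantization error;
    # an SQNR-style comparison is another option.
    torch.testing.assert_close(out_fp8, out_ref, atol=1e-1, rtol=1e-1)
    torch.testing.assert_close(x_fp8.grad, x_ref.grad, atol=1e-1, rtol=1e-1)
    for p_ref, p_fp8 in zip(m_ref.parameters(), m_fp8.parameters()):
        torch.testing.assert_close(p_fp8.grad, p_ref.grad, atol=1e-1, rtol=1e-1)
```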
Makes sense, will do this in a follow-up PR.
""" | ||
|
||
# Amax scales should always be kept as float32. | ||
self.always_float32_buffers = set() |
can delete, buffers are only needed for delayed scaling
ah yes I noticed buffers weren't needed for dynamic scaling but missed this one somehow - thanks!
emulate = config.emulate
super().__init__(*args, **kwargs)

# Defines the scaling behavior of input, weight, grad_output
can delete, this is only needed for supporting multiple types of scaling
self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type

self.config = config
self.is_amax_initialized = not self.config.enable_amax_init
can delete, this is only needed for delayed scaling
def forward(self, input: torch.Tensor) -> torch.Tensor:
    # TODO(danielvegamyhre): modify to support FSDP once dependencies are implemented
    output = self.forward_fp8_matmul(input)
    if self.bias is not None:
feel free to not support bias at all to simplify, as our target model uses linear without bias
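For illustration, a minimal sketch of how forward could look with bias support dropped entirely, reusing the forward_fp8_matmul helper from this diff (the assertion message is an assumption):

```python
def forward(self, input: torch.Tensor) -> torch.Tensor:
    # Bias is intentionally unsupported to keep the implementation minimal,
    # since the target model uses nn.Linear without bias.
    assert self.bias is None, "Float8LinearNoCompile does not support bias"
    return self.forward_fp8_matmul(input)
```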
def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
    # TODO(danielvegamyhre): replace scale calculation with triton kernel
    if tensor_already_casted_to_fp8(weight):
feel free to remove this to simplify
def from_float(
    cls,
    mod,
    config: Optional[Float8LinearConfig] = None,
nit: feel free to only support default settings and not give an option to modify the config
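One way a defaults-only constructor could look, sketched here as an assumption rather than the PR's final API (no config parameter; dynamic tensorwise scaling is implied, and the class is assumed to subclass nn.Linear):

```python
@classmethod
def from_float(cls, mod: torch.nn.Linear):
    # Only the default recipe (dynamic tensorwise scaling) is supported,
    # so no config argument is exposed.
    assert mod.bias is None, "bias is not supported"
    with torch.device("meta"):
        new_mod = cls(mod.in_features, mod.out_features, bias=False)
    new_mod.weight = mod.weight
    return new_mod
```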
    axiswise_dim: if axiswise granularity is used, defines the dim to scale across
"""
# TODO(danielvegamyhre): replace this torch implementation with custom triton kernel
if tensor_already_casted_to_fp8(hp_tensor):
nit: can remove to simplify
# TODO(danielvegamyhre): replace this torch implementation with custom triton kernel
if tensor_already_casted_to_fp8(hp_tensor):
    return hp_tensor
scale = tensor_to_scale(
I'd recommend just inlining the right code here for just the default settings instead of calling into util functions, it will be easier to compare that to handwritten kernels
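As a rough illustration of that suggestion (not code from this PR), the inlined default path for dynamic tensorwise scaling to e4m3 could look roughly like this; the clamp epsilon and dtype choices are assumptions:

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
EPS = 1e-12  # guards against division by zero for all-zero tensors

def tensorwise_dynamic_scale(hp_tensor: torch.Tensor) -> torch.Tensor:
    # One scale for the whole tensor, derived from the current abs-max
    # and kept in float32.
    amax = hp_tensor.abs().max().to(torch.float64)
    scale = FP8_E4M3_MAX / torch.clamp(amax, min=EPS)
    return scale.to(torch.float32)
```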
    scaling_granularity,
    axiswise_dim,
)
return hp_tensor_and_scale_to_float8(
I'd recommend just inlining the right code here for just the default settings instead of calling into util functions, it will be easier to compare that to handwritten kernels
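Similarly, a sketch of the inlined cast given a precomputed scale; note that the util it would replace additionally wraps the raw float8 data together with its scale in a tensor subclass, which is omitted here for brevity:

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def cast_with_scale(hp_tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Scale into the representable e4m3 range, clamp, then cast to float8.
    scaled = hp_tensor.to(torch.float32) * scale
    clamped = torch.clamp(scaled, min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX)
    return clamped.to(torch.float8_e4m3fn)
```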
lgtm, feel free to address changes in this PR or future PRs
Sounds good, I addressed all your comments except the extension to the test script, which I'll do in a follow-up PR - thanks!
Force-pushed from d1e79a1 to 450f140.
[float8nocompile] Simplified Float8Linear implementation which only supports dynamic tensorwise scaling (#1429)
* float8nocompile: add simplified implementation of float8linear which only supports dynamic tensorwise scaling
* address comments
Co-authored-by: Daniel Vega-Myhre <danvm@fb.com>
Summary:
This PR adds a simplified implementation of Float8Linear, dubbed Float8LinearNoCompile, which only supports dynamic tensorwise scaling. I've used TODOs to mark the places where the torch-based logic needs to be replaced with custom Triton kernels in order to improve eager-mode performance. Once those kernels have been implemented, I'll benchmark the performance and do some profiling to identify bottlenecks and additional optimization opportunities.

The purpose of starting with this is to have a simple implementation that works end to end for float8 training (as shown in the test plan section below), so that as I start replacing torch logic with new Triton kernels, I can use the end-to-end training example as a basic test to validate that it's still working.
Test plan:
Run examples/example.py