[float8nocompile] Simplified Float8Linear implementation which only supports dynamic tensorwise scaling #1429

Merged
2 commits merged into pytorch:main on Dec 18, 2024

Conversation

@danielvegamyhre (Contributor) commented Dec 17, 2024

Summary:

This PR adds a simplified implementation of Float8Linear, dubbed Float8LinearNoCompile, which only supports dynamic tensorwise scaling. I've used TODOs to mark the places where I need to replace the torch-based logic with custom triton kernels in order to improve eager mode performance. Once those kernels have been implemented, I'll benchmark the performance and do some profiling to identify bottlenecks and additional optimization opportunities.

The purpose of starting here is to have a simple implementation that works e2e for float8 training (as shown in the test plan section below), so that as I replace the torch logic with new triton kernels, I can use the e2e training example as a basic test to validate that everything still works.

Test plan:

  • Validated this simplified implementation works e2e by running a simple example in examples/example.py:
[danvm@devgpu006.vll6 /data/users/danvm/ao/torchao/prototype/float8nocompile/examples (linear1)]$ python3 example.py 
calling convert_to_float8_nocompile_training
finished convert_to_float8_nocompile_training
step 0
step 1
step 2
step 3
step 4
step 5
step 6
step 7
step 8
step 9
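
For reference, a minimal sketch of what such an e2e example might look like is below. Only convert_to_float8_nocompile_training comes from this PR; the import path, model shape, optimizer, and loss are illustrative assumptions.

```python
import torch
from torch import nn

# Import path is an assumption; the prototype lives under torchao/prototype/float8nocompile.
from torchao.prototype.float8nocompile import convert_to_float8_nocompile_training

# Toy model; layer shapes and dtype are illustrative.
m = nn.Sequential(
    nn.Linear(4096, 4096, bias=False),
    nn.Linear(4096, 128, bias=False),
).to(device="cuda", dtype=torch.bfloat16)

print("calling convert_to_float8_nocompile_training")
convert_to_float8_nocompile_training(m)  # assumed to swap nn.Linear children in place
print("finished convert_to_float8_nocompile_training")

optimizer = torch.optim.SGD(m.parameters(), lr=1e-3)
for i in range(10):
    x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
    loss = m(x).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"step {i}")
```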

pytorch-bot (bot) commented Dec 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1429

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 450f140 with merge base a5a53a2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Dec 17, 2024
@danielvegamyhre added the topic: not user facing label on Dec 18, 2024
In examples/example.py:

convert_to_float8_nocompile_training(m)
print("finished convert_to_float8_nocompile_training")

for i in range(10):
Reviewer (Contributor):
this makes sense, I'd recommend adding a stronger numerical equivalency check (can be in a future PR):

  1. create example input
  2. create (a) reference model and (b) your model
  3. feed the example input, run backwards
  4. compare that model output, weight gradients, input gradients are equivalent across (a) and (b)
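
As an illustration only, a minimal version of such a check could look like the sketch below; the import path, model shape, in-place swap behavior, and tolerances are assumptions, and only convert_to_float8_nocompile_training is taken from this PR.

```python
import copy
import torch
from torch import nn

# Import path is an assumption; the prototype lives under torchao/prototype/float8nocompile.
from torchao.prototype.float8nocompile import convert_to_float8_nocompile_training

# 1-2. example input, plus (a) reference model and (b) converted model with identical weights
x_ref = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
x_test = x_ref.detach().clone().requires_grad_(True)

ref_model = nn.Sequential(nn.Linear(4096, 4096, bias=False)).to(device="cuda", dtype=torch.bfloat16)
test_model = copy.deepcopy(ref_model)
convert_to_float8_nocompile_training(test_model)  # assumed to swap nn.Linear children in place

# 3. feed the example input and run backwards through both models
ref_out = ref_model(x_ref)
ref_out.sum().backward()
test_out = test_model(x_test)
test_out.sum().backward()

# 4. outputs, weight grads, and input grads should match within fp8-appropriate tolerances
torch.testing.assert_close(test_out, ref_out, atol=0.1, rtol=0.1)
torch.testing.assert_close(test_model[0].weight.grad, ref_model[0].weight.grad, atol=0.1, rtol=0.1)
torch.testing.assert_close(x_test.grad, x_ref.grad, atol=0.1, rtol=0.1)
```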

danielvegamyhre (author):
Makes sense, will do this in a follow up PR.

"""

# Amax scales should always be kept as float32.
self.always_float32_buffers = set()
Reviewer (Contributor):
can delete, buffers are only needed for delayed scaling

danielvegamyhre (author):
ah yes I noticed buffers weren't needed for dynamic scaling but missed this one somehow - thanks!

emulate = config.emulate
super().__init__(*args, **kwargs)

# Defines the scaling behavior of input, weight, grad_output
Reviewer (Contributor):
can delete, this is only needed for supporting multiple types of scaling

self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type

self.config = config
self.is_amax_initialized = not self.config.enable_amax_init
Reviewer (Contributor):
can delete, this is only needed for delayed scaling

def forward(self, input: torch.Tensor) -> torch.Tensor:
    # TODO(danielvegamyhre): modify to support FSDP once dependencies are implemented
    output = self.forward_fp8_matmul(input)
    if self.bias is not None:
Reviewer (Contributor):
feel free to not support bias at all to simplify, as our target model uses linear without bias
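
Under that suggestion, the forward path could collapse to something like the sketch below; forward_fp8_matmul is the method from the snippet above, and dropping the bias branch is the reviewer's proposal, not necessarily the merged code.

```python
def forward(self, input: torch.Tensor) -> torch.Tensor:
    # Bias-free path only: return the fp8 matmul output directly.
    return self.forward_fp8_matmul(input)
```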


def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
    # TODO(danielvegamyhre): replace scale calculation with triton kernel
    if tensor_already_casted_to_fp8(weight):
Reviewer (Contributor):
feel free to remove this to simplify

def from_float(
    cls,
    mod,
    config: Optional[Float8LinearConfig] = None,
Reviewer (Contributor):
nit: feel free to only support default settings and not give an option to modify the config

    axiswise_dim: if axiswise granularity is used, defines the dim to scale across
"""
# TODO(danielvegamyhre): replace this torch implementation with custom triton kernel
if tensor_already_casted_to_fp8(hp_tensor):
Reviewer (Contributor):
nit: can remove to simplify

# TODO(danielvegamyhre): replace this torch implementation with custom triton kernel
if tensor_already_casted_to_fp8(hp_tensor):
    return hp_tensor
scale = tensor_to_scale(
Reviewer (Contributor):
I'd recommend just inlining the right code here for just the default settings instead of calling into util functions, it will be easier to compare that to handwritten kernels

    scaling_granularity,
    axiswise_dim,
)
return hp_tensor_and_scale_to_float8(
Reviewer (Contributor):
I'd recommend just inlining the right code here for just the default settings instead of calling into util functions, it will be easier to compare that to handwritten kernels
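
For context, a rough sketch of the math these utilities perform for the default settings (dynamic tensorwise scaling) is shown below; the e4m3 target dtype and epsilon clamp are assumptions, and construction of the Float8Tensor wrapper, which is library-internal, is left out.

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def tensorwise_dynamic_cast(hp_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One fp32 scale for the whole tensor, derived from the current amax.
    amax = hp_tensor.abs().max().to(torch.float32)
    scale = FP8_E4M3_MAX / torch.clamp(amax, min=1e-12)
    # Scale, saturate to the fp8 range, then convert to float8.
    fp8_data = (hp_tensor.float() * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return fp8_data, scale
```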

@vkuzo (Contributor) left a comment:
lgtm, feel free to address changes in this PR or future PRs

@danielvegamyhre (author) replied:
Sounds good, I addressed all your comments except the extension to the test script, which I'll do in a follow up PR - thanks!

@danielvegamyhre merged commit ec64182 into pytorch:main on Dec 18, 2024
18 checks passed
amdfaa pushed a commit that referenced this pull request Jan 10, 2025
…upports dynamic tensorwise scaling (#1429)

* float8nocompile: add simplified implementation of float8linear which only supports dynamic tensorwise scaling

* address comments

---------

Co-authored-by: Daniel Vega-Myhre <danvm@fb.com>
Labels: CLA Signed, topic: not user facing