add the torch.float8_e8m0fnu dtype to PyTorch #147466


Closed

Conversation

Contributor

@vkuzo vkuzo commented Feb 19, 2025

Summary:

Continuing the work from #146427

Adds the `torch.float8_e8m0fnu` dtype to PyTorch, as detailed in
#146414. Please see the issue for a detailed definition of the format. Example of basic functionality:

```python
import torch

# round trip
x0 = torch.randn(4, 4, dtype=torch.float32)
x1 = x0.to(torch.float8_e8m0fnu)  # RNE rounding
x2 = x1.to(torch.float32)  # 2 ** exponent

# creation with empty
x0 = torch.empty(4, 4, dtype=torch.float8_e8m0fnu)

# printing
print(x0)
```
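For intuition, here is a minimal bit-level sketch (not part of this PR) of what these casts do, assuming the format described in #146414: 8 exponent bits, no sign, no mantissa, a bias of 127, and 0xFF reserved for NaN. The encode below truncates for simplicity; the actual cast uses RNE rounding.

```python
import struct

E8M0_BIAS = 127  # assumed: same bias as the fp32 exponent field

def e8m0_decode(bits: int) -> float:
    # The `2 ** exponent` direction: an e8m0 byte encodes a pure power of two.
    if bits == 0xFF:  # assumed NaN encoding per the format definition
        return float("nan")
    return 2.0 ** (bits - E8M0_BIAS)

def fp32_exponent_bits(x: float) -> int:
    # Extract the raw 8 exponent bits of an fp32 value (truncating encode;
    # the actual cast rounds to nearest-even instead).
    u32 = struct.unpack("<I", struct.pack("<f", x))[0]
    return (u32 >> 23) & 0xFF

print(e8m0_decode(fp32_exponent_bits(3.0)))  # 2.0 with truncation
```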

Done in this PR:

  • numerical correctness
  • op coverage (except for `torch._scaled_mm`): create tensor, cast to/from float32
  • printing a tensor works

For future PRs:

  • performance optimizations for casting
  • torch._scaled_mm
  • PT2
  • various cleanups (detailed in comments with issue numbers)

Test Plan:

pytest test/quantization/core/experimental/test_float8.py -s

Reviewers:

Subscribers:

Tasks:

Tags:

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10


pytorch-bot bot commented Feb 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147466

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d635775 with merge base 303ad19:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the "module: cpu" and "release notes: quantization" labels Feb 19, 2025
@vkuzo vkuzo changed the title "add the torch.float8_e8m0fnu` dtype to PyTorch" to "add the torch.float8_e8m0fnu dtype to PyTorch" Feb 19, 2025
@vkuzo
Contributor Author

vkuzo commented Feb 19, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the "ciflow/trunk" label Feb 19, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-binary-manywheel / manywheel-py3_9-cuda11_8-build / build

Details for Dev Infra team (raised by workflow job)

@vkuzo
Contributor Author

vkuzo commented Feb 20, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@henrylhtsang
Contributor

henrylhtsang commented Feb 20, 2025

false alarm my bad

Raymo111 pushed a commit that referenced this pull request Feb 20, 2025
pytorch-bot bot pushed a commit that referenced this pull request Feb 24, 2025
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
mengfei25 added a commit to mengfei25/pytorch that referenced this pull request Mar 6, 2025
jianyizh added a commit to jianyizh/pytorch that referenced this pull request Mar 6, 2025
@github-actions github-actions bot deleted the 20250219_e8m0_intermediate branch March 27, 2025 02:11
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Apr 3, 2025
@yiakwy-xpu-ml-framework-team

yiakwy-xpu-ml-framework-team commented Apr 23, 2025

Hi @vkuzo, are you still working on this? I think the way MXFP8_E8M0_FNU is used here is worth discussing.

The number is used inside a group-quantize function: each warp extracts 32 fp32 exponents from 32 fp32 values, so ocp_fp8e8m0fnu_from_fp32 simply amounts to extracting the fp32 exponent. No mantissa or special-value handling needs special care, because the exponent itself is unsigned.

Fp32 -> fp8 + mxfp8_e8m0_fnu

The second point is that this mxfp8_e8m0_fnu scale is shared by a group of 32 consecutive elements, each of which keeps only its mantissa (with or without the implicit leading 1), because multiplying the mantissa by the shared exponent reconstructs the fp32 value.

Note that multiplying an fp8 value by an fp8 scale is as easy as `fp32 = (fp8_scale << fp32_t::M) | (fp8 & fp32_t::INT32_M_MASK);`

Here is my implementation:

```cpp
template <>
HOST_DEVICE_INLINE Float8_E8M0_FNU ocp_fp8e8m0fnu_from_fp32(float fval) {
    using fp32_t = Float;
    using fp8_t = Float8_E8M0_FNU;
    using fp8_storage_t = Float8_E8M0_FNU::Datum;

    // Reinterpret the float's bits without changing them.
    union {
        float fval;
        int32_t i32val;
        uint32_t ui32val;
    } val;

    val.fval = fval;

    // Mask out the 8 exponent bits and shift them down past the mantissa.
    fp8_storage_t ui8val = (val.i32val & fp32_t::INT32_E_MASK) >> fp32_t::M;
    return fp8_t::from_bits(ui8val);
}
```
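To make the shared-scale multiply sketched above concrete: a minimal, hypothetical Python illustration (not from this PR or the comment's codebase) that applies an e8m0 scale to an fp32 value by adding the scale's exponent into the fp32 exponent field. It ignores overflow, NaN, and subnormals.

```python
import struct

FP32_M = 23       # fp32 mantissa width
E8M0_BIAS = 127   # assumed e8m0 bias, same as fp32's exponent bias

def apply_e8m0_scale(x: float, scale_bits: int) -> float:
    # Multiply x by the e8m0 scale 2**(scale_bits - 127) by adding the
    # scale's unbiased exponent directly into x's fp32 exponent field.
    u32 = struct.unpack("<I", struct.pack("<f", x))[0]
    exp = (u32 >> FP32_M) & 0xFF
    exp += scale_bits - E8M0_BIAS  # exponent addition == multiply by 2**k
    u32 = (u32 & ~(0xFF << FP32_M)) | ((exp & 0xFF) << FP32_M)
    return struct.unpack("<f", struct.pack("<I", u32))[0]

# 1.5 scaled by 2**(129 - 127) == 2**2:
print(apply_e8m0_scale(1.5, 129))  # 6.0
```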
