
refine fp32 precision api #125888


Closed
wants to merge 83 commits

Conversation

zhuhaozhe
Collaborator

@zhuhaozhe zhuhaozhe commented May 10, 2024

Based on the conversation, we plan to drop "highest, high, medium" as the way to represent fp32 internal computation data types. Instead, we will use the algorithm name directly.

Design Choice: Directly use algorithm names like "TF32" and "BF16".

Pros

  • The names are more informative: 'tf32' says more than a generic "high".
  • It is easier to extend to new algorithms such as tf32x3.

Cons

  • "HIGHEST, HIGH, MEDIUM" indicated the relative precision between different algorithms. However, we can have more documents to discuss them.

We provide a layered structure for backends/operators.

('f32' is short for 'fp32_precision')
(image: diagram of the layered fp32_precision structure for backends/operators)

We provide four fp32 compute precisions that can be set:

  • "ieee": no other internal computation data types are allowed.
  • "tf32": tf32 is allowed as the internal computation data type.
  • "bf16": bf16 is allowed as the internal computation data type.
  • "none": the precision is not set and can be overridden by its parent node.

Overriding Precision Settings

A child node is overridden by its parent node if the child is left at the default.
The current default settings are:

backend = generic, op = all, precision setting = none
    backend = cuda, op = all, precision setting = none
        backend = cuda, op = conv, precision setting = tf32
        backend = cuda, op = rnn, precision setting = tf32
        backend = cuda, op = matmul, precision setting = none
    backend = mkldnn, op = all, precision setting = none
        backend = mkldnn, op = conv, precision setting = none
        backend = mkldnn, op = rnn, precision setting = none
        backend = mkldnn, op = matmul, precision setting = none
  • If the user sets torch.backends.mkldnn.fp32_precision="bf16", its child nodes torch.backends.mkldnn.matmul.fp32_precision / torch.backends.mkldnn.conv.fp32_precision / torch.backends.mkldnn.rnn.fp32_precision are also overridden to "bf16" (see the sketch below).
  • If the user sets torch.backends.fp32_precision="bf16", torch.backends.mkldnn.fp32_precision and its child nodes are also overridden to "bf16".
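
A minimal sketch of this override behavior, assuming the attribute names used in this description; the exact values returned by the getters are illustrative:

```python
>>> import torch
>>> torch.backends.mkldnn.matmul.fp32_precision    # child left at its default
'none'
>>> torch.backends.mkldnn.fp32_precision = "bf16"  # set the parent (backend-level) precision
>>> torch.backends.mkldnn.matmul.fp32_precision    # the child is now overridden to the parent's value
'bf16'
```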

Backward Compatibility

Since the new API allows more fine-grained control, there can be conflicts with the old API. For example, the previous torch.backends.cudnn.allow_tf32 flag is not enough to represent a state such as torch.backends.cudnn.rnn.fp32_precision="ieee" combined with torch.backends.cudnn.conv.fp32_precision="tf32". Therefore, our goals for backward compatibility are:

  • If the user only uses the previous APIs, they will work as before.
  • If the user uses the new API to move into a state that the old API cannot represent, and then queries that state through the old API, we raise a RuntimeError and point the user to the documentation (see the illustration below).
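
The following hypothetical snippet illustrates such an un-representable state, using the attribute names from this description; how the conflict is surfaced (error or warning) follows the policy above and the later discussion in this thread:

```python
import torch

# Mixed per-op settings that a single boolean flag cannot represent:
torch.backends.cudnn.conv.fp32_precision = "tf32"
torch.backends.cudnn.rnn.fp32_precision = "ieee"

# Reading the legacy flag now has no single correct answer, so per the
# goal above the old getter surfaces the conflict and points the user to
# the documentation rather than silently guessing.
torch.backends.cudnn.allow_tf32
```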

Test Plan

python test/test_cuda.py -k test_fp32_precision_with_tf32
python test/test_cuda.py -k test_fp32_precision_with_float32_matmul_precision
python test/test_cuda.py -k test_invalid_status_for_legacy_api
python test/test_mkldnn.py -k test_mlkdnn_get_set
python test/test_mkldnn.py -k test_generic_precision
python test/test_mkldnn.py -k test_invalid
python test/test_mkldnn.py -k test_default_use_parent

Stack from ghstack (oldest at bottom):

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal @voznesenskym @penguinwu @EikanWang @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @rec

@zhuhaozhe zhuhaozhe requested a review from eqy as a code owner May 10, 2024 01:20

pytorch-bot bot commented May 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125888

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (6 Unrelated Failures)

As of commit 467fe0e with merge base 78ee2ee:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label May 10, 2024
@zhuhaozhe zhuhaozhe changed the title refine fp32 precision api [WIP] refine fp32 precision api May 10, 2024
@zhuhaozhe zhuhaozhe marked this pull request as draft May 10, 2024 01:21
@zhuhaozhe zhuhaozhe added the ciflow/trunk Trigger trunk jobs on your pull request label May 10, 2024
[ghstack-poisoned]
zhuhaozhe added a commit that referenced this pull request May 10, 2024
ghstack-source-id: e9d5141
Pull Request resolved: #125888
cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
zhuhaozhe added a commit that referenced this pull request May 10, 2024
ghstack-source-id: 73f3cfd
Pull Request resolved: #125888
cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
zhuhaozhe added a commit that referenced this pull request May 10, 2024
ghstack-source-id: a4c02dc
Pull Request resolved: #125888
cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
@eqy
Collaborator

eqy commented May 10, 2024

CC @mruberry who authored #76440 , @ptrblck
as it seems like this PR abandons "medium, high, highest"

@zhuhaozhe
Collaborator Author

zhuhaozhe commented May 11, 2024

CC @mruberry who authored #76440 , @ptrblck as it seems like this PR abandons "medium, high, highest"

Hi, @eqy. This is a WIP draft PR based on the conversation here: #121791. It requested your review automatically.
I will summarize the design options asap, thanks.

@zhuhaozhe zhuhaozhe removed the request for review from eqy May 11, 2024 01:52
Based on the [conversation](#121791), we plan to drop "highest, high, medium" as the way to represent fp32 internal computation data types. Instead, we will use the algorithm name directly.

### Design Choice: Directly use algorithm names like "TF32", "BF16".
#### Pros
 - The names are more informative: 'tf32' says more than a generic "HIGH".
 - It is easier to extend to new algorithms such as `tf32x3`.
#### Cons
 - "HIGHEST, HIGH, MEDIUM" indicated the relative precision of the different algorithms; however, we can add more documentation to discuss this.

### We provide a layered structure for backends/operators.
('f32' is short for 'fp32_precision')
![image](https://github.com/pytorch/pytorch/assets/54701539/9cddf275-071c-4f69-a5ee-1540f78ac7f4)

### We provide 4 fp32 compute precisions that can be set:
 - **"ieee"**: computation happens in pure FP32; BF16 and TF32 are not allowed.
 - **"tf32"**: tf32 is allowed as the internal computation data type.
 - **"bf16"**: bf16 is allowed as the internal computation data type.
 - **"default"**: no specific precision is set; we look up its parent's precision in the layered structure.

### Examples
```python
# change top level fp32_precision from default value "ieee" to "tf32"
>>> torch.backends.fp32_precision
"ieee"
>>> torch.backends.fp32_precision="tf32"
>>> torch.backends.fp32_precision
"tf32"
```
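
The layered structure also allows changing a single op on a single backend. A small sketch assuming the same attribute paths as above; the return values shown are illustrative:

```python
# set only the cuda matmul op, leaving everything else at its default
>>> torch.backends.cuda.matmul.fp32_precision = "tf32"
>>> torch.backends.cuda.matmul.fp32_precision
'tf32'
>>> torch.backends.mkldnn.matmul.fp32_precision   # other backends are unaffected
'ieee'
```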

### Backward Compatibility
Since the new API allows control at the op level, there can be conflicts with the old API. For example, the previous `torch.backends.cudnn.allow_tf32` controls both conv and rnn, while we now provide `torch.backends.cudnn.conv.fp32_precision` and `torch.backends.cudnn.rnn.fp32_precision`. The "set" path of `torch.backends.cudnn.allow_tf32 = xyz` can work in a BC way, but for the "get" path a single `torch.backends.cudnn.allow_tf32` flag is not enough to represent the status of two operators, so we raise a warning there.
```python
# When the user uses torch.backends.cudnn.allow_tf32 to "set", we set both
>>> torch.backends.cudnn.conv.fp32_precision
'tf32'
>>> torch.backends.cudnn.rnn.fp32_precision
'tf32'
>>> torch.backends.cudnn.allow_tf32 = False
>>> torch.backends.cudnn.conv.fp32_precision
'ieee'
>>> torch.backends.cudnn.rnn.fp32_precision
'ieee'
# When the user uses torch.backends.cudnn.allow_tf32 to "get", we return True only when both
# fp32_precision settings are `tf32`. If the settings for `conv` and `rnn` differ, we warn the
# user about the actual situation.
>>> torch.backends.cudnn.allow_tf32 = True
>>> torch.backends.cudnn.rnn.fp32_precision = "ieee"
>>> torch.backends.cudnn.allow_tf32
[W511 16:22:07.017584786 Context.cpp:152] Warning: We allow to set different float32 precision for conv and rnn but your are querying float32 precision without a specific op.The current float32 precision for conv is tf32 and for rnn is ieee (function allowTF32CuDNN)
False
```
We have a similar situation between `torch.float32_matmul_precision` and `torch.backends.cuda.matmul.fp32_precision` / `torch.backends.mkldnn.matmul.fp32_precision`. The `set` method of `torch.set_float32_matmul_precision` works in a BC way, and we raise a warning for the `get` method `torch.get_float32_matmul_precision`.
```python
# set method
>>> torch.backends.cuda.matmul.fp32_precision
'ieee'
>>> torch.get_float32_matmul_precision()
'highest'
>>> torch.backends.cuda.matmul.fp32_precision
'ieee'
>>> torch.backends.mkldnn.matmul.fp32_precision
'ieee'
>>> torch.set_float32_matmul_precision("medium")
>>> torch.backends.cuda.matmul.fp32_precision
'tf32'
>>> torch.backends.mkldnn.matmul.fp32_precision
'bf16'
# get method
>>> torch.set_float32_matmul_precision("highest")
>>> torch.backends.cuda.matmul.fp32_precision = "tf32"
>>> torch.get_float32_matmul_precision()
[W511 18:14:56.441053716 Context.cpp:289] Warning: We allow to set different float32 matmul precision for mkldnn and cuda but you are querying float32 matmul precision without a specific backend.The current float32 matmul precision for cuda is tf32 and for mkldnn is ieee (function operator())
'highest'
```





cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
zhuhaozhe added a commit that referenced this pull request May 13, 2024
ghstack-source-id: 4634f6b
Pull Request resolved: #125888
@yanbing-j
Collaborator

Hi @atalman, @jithunnair-amd and @jeffdaily,

The direct root cause of the PYTORCH_TEST_WITH_ROCM=1 python test/inductor/test_flex_decoding.py TestFlexDecodingCUDA.test_non_sparse_mulitple_block_size_cuda failure is a mismatch between the logic in allowTF32CuBLAS() and MI300 TF32 support, which uses HIPBLASLT_ALLOW_TF32=1 to indicate that tf32 is allowed and enabled for cublas matmul on MI300, instead of using float32_matmul_precision to decide whether tf32 is allowed. The UT only uses torch.set_float32_matmul_precision("high") to set float32_matmul_precision, but the env var is not set.

Since I'm not familiar with HIP cuBLAS internals, I just set float32_matmul_precision to "highest", as it is in the main branch, when running on HIP.

I'm not sure whether HIPBLASLT_ALLOW_TF32=1 and the MI300-specific logic are a workaround to bypass "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default", as mentioned in setAllowTF32CuBLAS. Will AMD support TF32 on MI300 without this specific logic in the future? Thanks!

Please correct me if I misunderstand.

@yanbing-j
Collaborator

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

yanbing-j pushed a commit to yanbing-j/pytorch that referenced this pull request Jun 30, 2025
ghstack-source-id: 63962ec
Pull Request resolved: pytorch#125888
@jansel
Contributor

jansel commented Jul 12, 2025

@zhuhaozhe @albanD this PR is causing:

>>> import torch
>>> torch.backends.cudnn.allow_tf32 = True
/home/jansel/pytorch/torch/backends/__init__.py:46: UserWarning: This API is going to be deprecated, please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /home/jansel/pytorch/aten/src/ATen/Context.cpp:78.)
  self.setter(val)
>>> 

Which seems bad:

  1. The error message doesn't say which API is being deprecated, so if I get this printout from a large model it is hard to figure out what "This API" means. I was only able to figure out what the error was talking about by grepping the PyTorch source code. I think this will confuse users.
  2. The webpage the error message links to tells me to use torch.backends.cudnn.allow_tf32 = True (the exact thing causing the error) with no message about deprecations.

@yanbing-j
Collaborator

yanbing-j commented Jul 14, 2025

@zhuhaozhe @albanD this PR is causing:

>>> import torch
>>> torch.backends.cudnn.allow_tf32 = True
/home/jansel/pytorch/torch/backends/__init__.py:46: UserWarning: This API is going to be deprecated, please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /home/jansel/pytorch/aten/src/ATen/Context.cpp:78.)
  self.setter(val)
>>> 

Which seems bad:

  1. The error message doesn't say which API is being deprecated, so if I get this printout from a large model it is hard to figure out what "This API" means. I was only able to figure out what the error was talking about by grepping the PyTorch source code. I think this will confuse users.
  2. The webpage the error message links to tells me to use torch.backends.cudnn.allow_tf32 = True (the exact thing causing the error) with no message about deprecations.

Hi @jansel, Thanks for pointing this out!

I drafted #158209 to complete the API warning and to update the webpage content so it more clearly suggests that users adopt the new API settings.

Now the warning is updated to

>>> import torch
>>> torch.backends.cudnn.allow_tf32 = True
/home/yanbingj/projects/pytorch/torch/backends/__init__.py:46: UserWarning: Suggest to use a new setting of API control of a more fine-grained TF32 behavior, e.g, torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old setting, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() are still supported, and is going to be deprecated. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /home/yanbingj/projects/pytorch/aten/src/ATen/Context.cpp:78.)
  self.setter(val)

Please let me know if #158209 can help. Thanks!

@jansel
Contributor

jansel commented Jul 14, 2025

I left some comments on that PR. Perhaps we should just remove the warning for now, then roll the changes out as follows:

  1. Add the new APIs
  2. Update the docs to direct people to the new APIs (currently the docs still tell people to use the "deprecated" APIs)
  3. If we plan to keep the old API forever, don't make old API emit a warning
  4. If we plan to delete the old API, emit a warning with a schedule for when we will delete the old API

@yanbing-j
Collaborator

We want to deprecate old APIs. I have updated the docs to direct people to the new APIs in #158209. Please take a look again!

pytorchmergebot pushed a commit that referenced this pull request Jul 17, 2025
### Description

This PR is to enable TF32 as fp32 internal precision for matmul/linear/conv in `mkldnn backend`. Since we have refined fp32 precision API in #125888, we can easily extend the API to support TF32 for `mkldnn backend`.

```
torch.backends.mkldnn.matmul.fp32_precision = 'tf32'
torch.backends.mkldnn.conv.fp32_precision = "tf32"
```

Related kernel updates and UT updates are done. The wrapper `bf32_on_and_off` is updated to `reduced_f32_on_and_off`, and it can run tests 3 times: one with reduced_f32 OFF and the other two with reduced_f32 ON (covering `bf32 ON` and `tf32 ON`). A sketch of the idea follows.
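
A minimal sketch of such a run-three-times wrapper, not the actual `reduced_f32_on_and_off` helper from this PR; the decorator name and structure here are illustrative, while the fp32_precision knobs are the ones introduced by this stack:

```python
import functools
import torch

def reduced_f32_on_and_off_sketch(test_fn):
    """Illustrative only: run a test once per mkldnn fp32 precision mode."""
    @functools.wraps(test_fn)
    def wrapper(*args, **kwargs):
        # one run with reduced_f32 OFF ("ieee"), two runs with it ON ("bf16", "tf32")
        for precision in ("ieee", "bf16", "tf32"):
            torch.backends.mkldnn.matmul.fp32_precision = precision
            torch.backends.mkldnn.conv.fp32_precision = precision
            test_fn(*args, **kwargs)
    return wrapper
```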

Pull Request resolved: #157520
Approved by: https://github.com/mingfeima, https://github.com/jansel
@github-actions github-actions bot deleted the gh/zhuhaozhe/28/head branch August 14, 2025 02:19
pytorchmergebot pushed a commit that referenced this pull request Aug 21, 2025
…l.fp32_precision` (#161102)

For #161022
The warning says the old API will be deprecated in 2.9+ anyway, leaving it up to the author of #125888 to decide on initialization behavior then

Pull Request resolved: #161102
Approved by: https://github.com/ngimel, https://github.com/drisspg, https://github.com/BoyuanFeng
Labels
ci-no-td Do not run TD on this PR ciflow/inductor ciflow/linux-aarch64 linux aarch64 CI workflow ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/trunk Trigger trunk jobs on your pull request Merged module: cpu CPU specific problem (e.g., perf, algorithm) module: dynamo module: inductor module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration open source release notes: python_frontend python frontend release notes category Reverted
Projects
Status: Done