Set Float32 Precision for CONV/RNN #6

Open
zhuhaozhe opened this issue Jan 23, 2024 · 7 comments

zhuhaozhe commented Jan 23, 2024

RFC: Extend set fp32 precision API to support Convolution and RNN

Overview

This RFC proposes adding a user-controlled frontend API that configures the internal precision of float32 convolution (CONV) and recurrent neural network (RNN) operations in PyTorch. PyTorch already offers torch.set_float32_matmul_precision to configure the internal precision of float32 matrix multiplication. This RFC proposes extending that functionality to convolution and RNN operations by providing torch.set_float32_conv_precision and torch.set_float32_rnn_precision. The proposed APIs will mimic the behavior of torch.set_float32_matmul_precision.

Frontend Changes

Frontend changes involve introducing new APIs:

  • torch.set_float32_conv_precision, torch.get_float32_conv_precision
  • torch.set_float32_rnn_precision, torch.get_float32_rnn_precision

These APIs will function like torch.set_float32_matmul_precision and torch.get_float32_matmul_precision. Users can set the precision to "highest", "high", or "medium", with the following backend behavior:

  • highest: Use the highest available precision, avoiding lower precision.
  • high: Allow backends to use TensorFloat32 (TF32) or to treat each float32 number as the sum of two bfloat16 numbers.
  • medium: Allow backends to use bfloat16 (BF16).
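To make the "sum of two bfloat16 numbers" option under high more concrete, here is a small pure-Python sketch that splits a float32 value x into hi + lo, where both terms are bfloat16 values. The helper names (to_float32, truncate_to_bf16, split_bf16x2) are hypothetical, and the sketch truncates to bfloat16 by dropping the low 16 bits of the float32 encoding; real backends may round instead.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack(">f", struct.pack(">f", x))[0]


def truncate_to_bf16(x: float) -> float:
    """Keep only the top 16 bits of the float32 encoding of x
    (sign, 8 exponent bits, 7 explicit mantissa bits). Truncation is an assumption."""
    (bits,) = struct.unpack(">I", struct.pack(">f", x))
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]


def split_bf16x2(x: float) -> tuple[float, float]:
    """Hypothetical helper: approximate a float32 value as hi + lo, both bfloat16."""
    hi = truncate_to_bf16(x)
    # For normal inputs the residual x - hi is exactly representable in float32.
    lo = truncate_to_bf16(x - hi)
    return hi, lo


x = to_float32(1.0 / 3.0)
hi, lo = split_bf16x2(x)
print(abs(x - hi) / abs(x))          # relative error of a single bfloat16 term
print(abs(x - (hi + lo)) / abs(x))   # relative error of the two-term sum
```

For normal float32 inputs, truncation keeps the first relative error below 2^-7 and the two-term error below 2^-14, which is why this mode sits between plain bfloat16 and full float32 precision.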

Backend Changes

Global flags float32_conv_precision and float32_rnn_precision will be introduced at this location in the PyTorch repository. They can be read and modified through the frontend APIs torch.get/set_float32_conv/rnn_precision. Backend operators will read these flags to decide their internal computation data types. For example:

  • For the cuDNN backend, the cuDNN convolution kernel should check float32_conv_precision, and the cuDNN RNN kernel should check float32_rnn_precision. If the flag is not set to highest, the kernel is allowed to use TF32 as its internal computation data type.
  • For the oneDNN backend, the oneDNN convolution kernel should check float32_conv_precision, and the oneDNN RNN kernel should check float32_rnn_precision. If the flag is set to medium, the kernel is allowed to use BF16 as its internal computation data type.
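The backend rules above amount to a lookup from the global flag to an internal compute type. The following pure-Python sketch models that lookup. The module-level dictionary, the function names, and the labels "fp32", "tf32", and "bf16" are hypothetical stand-ins for the global context and the real kernel-selection code, and only the cuDNN and oneDNN rules stated above are encoded.

```python
VALID_PRECISIONS = ("highest", "high", "medium")

# Hypothetical stand-in for the proposed global flags; initial values are an assumption.
_float32_precision = {"conv": "highest", "rnn": "highest"}


def set_float32_precision(op: str, precision: str) -> None:
    """Model of the proposed torch.set_float32_conv/rnn_precision setters."""
    if op not in _float32_precision:
        raise ValueError(f"unknown op kind: {op!r}")
    if precision not in VALID_PRECISIONS:
        raise ValueError(f"precision must be one of {VALID_PRECISIONS}, got {precision!r}")
    _float32_precision[op] = precision


def internal_compute_type(backend: str, op: str) -> str:
    """Return the internal data type a kernel is allowed to use, per the rules above."""
    precision = _float32_precision[op]
    if backend == "cudnn":
        # cuDNN: anything other than "highest" allows TF32.
        return "fp32" if precision == "highest" else "tf32"
    if backend == "onednn":
        # oneDNN: only "medium" allows BF16.
        return "bf16" if precision == "medium" else "fp32"
    raise ValueError(f"unknown backend: {backend!r}")


set_float32_precision("conv", "medium")
print(internal_compute_type("cudnn", "conv"), internal_compute_type("onednn", "conv"))  # tf32 bf16
```

Keeping one precision-to-dtype rule per backend means a new backend only has to define its own mapping, while the frontend flag stays the same.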

Flag Overrides

The existing cuDNN backend-specific flag torch.backends.cudnn.allow_tf32 will interact with the proposed backend-agnostic flags set by torch.set_float32_conv/rnn_precision. These flags will override each other, following the existing behavior of torch.backends.cuda.matmul.allow_tf32 and float32_matmul_precision:

  • Turning TF32 on or off with torch.backends.cudnn.allow_tf32 will set both float32_conv_precision and float32_rnn_precision to high (TF32 enabled) or highest (TF32 disabled).
import torch
torch.backends.cudnn.allow_tf32 = True
print("float32_conv_precision", torch.get_float32_conv_precision())
print("float32_rnn_precision", torch.get_float32_rnn_precision())
# output:
# float32_conv_precision high
# float32_rnn_precision high
torch.backends.cudnn.allow_tf32 = False
print("float32_conv_precision", torch.get_float32_conv_precision())
print("float32_rnn_precision", torch.get_float32_rnn_precision())
# output:
# float32_conv_precision highest
# float32_rnn_precision highest
  • Setting both float32_conv_precision and float32_rnn_precision to high or medium will enable torch.backends.cudnn.allow_tf32, while setting either of them to highest will disable it.
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_conv_precision("highest")
print("torch.backends.cudnn.allow_tf32", torch.backends.cudnn.allow_tf32)
# output:
# torch.backends.cudnn.allow_tf32 False
torch.set_float32_rnn_precision("highest")
print("torch.backends.cudnn.allow_tf32", torch.backends.cudnn.allow_tf32)
# output:
# torch.backends.cudnn.allow_tf32 False
torch.set_float32_conv_precision("high")
torch.set_float32_rnn_precision("high")
print("torch.backends.cudnn.allow_tf32", torch.backends.cudnn.allow_tf32)
# output:
# torch.backends.cudnn.allow_tf32 True
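The override rules can be checked with a small self-contained model. This is a hypothetical pure-Python sketch, not PyTorch's implementation: the class Float32PrecisionFlags and its set_precision method are illustrative names, and allow_tf32 is computed from the two precision flags instead of being stored.

```python
class Float32PrecisionFlags:
    """Hypothetical model of the proposed coupling between the conv/RNN precision
    flags and torch.backends.cudnn.allow_tf32 (not PyTorch's actual code)."""

    VALID = ("highest", "high", "medium")

    def __init__(self) -> None:
        # Initial values are an assumption of this sketch.
        self.conv = "highest"
        self.rnn = "highest"

    def set_precision(self, op: str, precision: str) -> None:
        """Model of torch.set_float32_conv_precision / torch.set_float32_rnn_precision."""
        if op not in ("conv", "rnn"):
            raise ValueError(f"unknown op kind: {op!r}")
        if precision not in self.VALID:
            raise ValueError(f"unknown precision: {precision!r}")
        setattr(self, op, precision)

    @property
    def allow_tf32(self) -> bool:
        # Derived, not stored: TF32 counts as enabled only if neither flag is "highest".
        return self.conv != "highest" and self.rnn != "highest"

    @allow_tf32.setter
    def allow_tf32(self, enabled: bool) -> None:
        # Toggling the cuDNN flag rewrites both precision flags.
        level = "high" if enabled else "highest"
        self.conv = level
        self.rnn = level


# Replay the sequence from the examples above.
flags = Float32PrecisionFlags()
flags.allow_tf32 = True
print(flags.conv, flags.rnn)  # high high
flags.set_precision("conv", "highest")
print(flags.allow_tf32)  # False
flags.set_precision("rnn", "highest")
print(flags.allow_tf32)  # False
flags.set_precision("conv", "high")
flags.set_precision("rnn", "high")
print(flags.allow_tf32)  # True
```

Deriving allow_tf32 from the two precision flags keeps a single source of truth, so the backend-specific and backend-agnostic views cannot drift apart.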

Additional cuDNN Flags

The previous section describes how the existing cuDNN flag torch.backends.cudnn.allow_tf32 interacts with torch.set_float32_conv/rnn_precision. However, we believe separate cuDNN flags would be cleaner. We suggest deprecating torch.backends.cudnn.allow_tf32 in favor of torch.backends.cudnn.conv.allow_tf32 and torch.backends.cudnn.rnn.allow_tf32. The backend-specific flags and the backend-agnostic flags would then have a one-to-one correspondence, as torch.backends.cuda.matmul.allow_tf32 and torch.float32_matmul_precision already do:

torch.backends.cudnn.conv.allow_tf32 <-> torch.float32_conv_precision
torch.backends.cudnn.rnn.allow_tf32 <-> torch.float32_rnn_precision
# this pair already exists today
torch.backends.cuda.matmul.allow_tf32 <-> torch.float32_matmul_precision
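With per-op flags, each cuDNN flag becomes a view of exactly one precision value, so turning TF32 off for convolution no longer affects RNN. The sketch below is a hypothetical pure-Python model of that one-to-one mapping. The CudnnOpFlag class and the shared dictionary are illustrative stand-ins, and treating both "high" and "medium" as TF32-enabled in the getter is an assumption about how the proposed flags would report their state.

```python
class CudnnOpFlag:
    """Hypothetical per-op cuDNN flag (e.g. torch.backends.cudnn.conv.allow_tf32)
    backed by a single backend-agnostic precision value."""

    def __init__(self, store: dict, op: str) -> None:
        self._store = store  # shared mapping: op kind -> "highest" | "high" | "medium"
        self._op = op

    @property
    def allow_tf32(self) -> bool:
        # Assumption: any precision other than "highest" reports TF32 as enabled.
        return self._store[self._op] != "highest"

    @allow_tf32.setter
    def allow_tf32(self, enabled: bool) -> None:
        # Only this op's precision is rewritten.
        self._store[self._op] = "high" if enabled else "highest"


precision = {"conv": "high", "rnn": "high"}
conv = CudnnOpFlag(precision, "conv")
rnn = CudnnOpFlag(precision, "rnn")

conv.allow_tf32 = False  # only the conv entry changes
print(precision)         # {'conv': 'highest', 'rnn': 'high'}
print(rnn.allow_tf32)    # True
```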

Motivation

Lower-precision computation offered by different backends, such as TF32 in CUDA/cuDNN or the implicit reduced-precision arithmetic feature in oneDNN, can significantly improve performance for deep learning workloads with minimal loss of accuracy. A user-controlled frontend API lets users configure the internal computation data type of convolution and RNN operations without knowing the details of each backend. This allows them to benefit from lower precision while keeping the precision loss acceptable. Compared to Autocast, the proposed flags offer:

  • Finer-grained control, since they only affect the internal data types of convolution and RNN operations, whereas Autocast affects many more operators.
  • Ease of use, as users do not need to modify their model scripts to enable autocasting.

Pitch

Introduce float32_conv/rnn_precision and enable users to control the internal data type for convolutional and recurrent neural networks by configuring the value of float32_conv/rnn_precision.

@zhuhaozhe zhuhaozhe changed the title xx USE the bfloat16 datatype (8 mantissa bits) for internal computations with GEMM/CONV/RNN Jan 25, 2024
@leslie-fang-intel

When the precision is high, the CUDA/CUDNN backend will be allowed to use TF32 as the internal computation data type. When the precision is medium, the MKLDNN backend will be allowed to use BF16 as the internal computation data type.

According to https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch-set-float32-matmul-precision, shouldn't it apply to all backends?

@leslie-fang-intel

For the two design options in Frontend API and Inductor linear packable, do we have a preferred option now? If so, we can discuss our preferred implementation.

@zhuhaozhe zhuhaozhe changed the title USE the bfloat16 datatype (8 mantissa bits) for internal computations with GEMM/CONV/RNN Set Float32 Precision GEMM/CONV/RNN Feb 1, 2024
@zhuhaozhe

> When the precision is high, the CUDA/CUDNN backend will be allowed to use TF32 as the internal computation data type. When the precision is medium, the MKLDNN backend will be allowed to use BF16 as the internal computation data type.
>
> Refer to https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch-set-float32-matmul-precision, it should apply to all the backends?

Yes, I changed it to apply to all backends rather than only MKLDNN or CUDA.

@zhuhaozhe

Thanks, changed.

@jgong5

jgong5 commented Feb 1, 2024

Please add notes on how CUDA can support the new frontend APIs, since these are general APIs that apply to all backends.

@zhuhaozhe

Thanks for the advice, added.

@zhuhaozhe zhuhaozhe changed the title Set Float32 Precision GEMM/CONV/RNN Set Float32 Precision CONV/RNN Mar 11, 2024
@zhuhaozhe zhuhaozhe changed the title Set Float32 Precision CONV/RNN Set Float32 Precision for CONV/RNN Mar 11, 2024