Use the bfloat16 datatype (8 mantissa bits) for internal computations with GEMM/CONV/RNN #5

zhuhaozhe commented Jan 17, 2024

🚀 The Feature

This RFC proposes to use BFloat16 for GEMM/CONV/RNN internal computations on the CPU device, with a user-controlled frontend API. Currently, we have torch.set_float32_matmul_precision, which allows float32 matrix multiplications to run in lower precision:

  • highest -> Do not use lower precision.
  • high -> Use TF32 as the internal computation data type.
  • medium -> Use BF16 as the internal computation data type (usage sketched below).
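For reference, a minimal sketch of how the existing matmul knob is used today (assuming a recent PyTorch build); the conv/RNN knobs proposed below are intended to mirror this interface:

```python
import torch

# The existing backend-agnostic knob; "highest" is the default.
torch.set_float32_matmul_precision("medium")
print(torch.get_float32_matmul_precision())  # -> "medium"

a = torch.randn(128, 128)  # float32 inputs
b = torch.randn(128, 128)
c = a @ b                  # the backend may now use a lower-precision internal dtype
```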
To allow CONV/RNN to also have a configurable internal computation data type for float32, and to integrate mkldnn BF16 as an internal computation data type for GEMM/CONV/RNN on the CPU device, we propose the high-level code changes below.

Frontend changes:

  • We propose to provide the following frontend APIs:
    • torch.set_float32_conv_precision, torch.get_float32_conv_precision
    • torch.set_float32_rnn_precision, torch.get_float32_rnn_precision

These frontend APIs should behave the same as torch.set_float32_matmul_precision and torch.get_float32_matmul_precision: users can set the precision to highest, high, or medium. When the precision is high, the CUDA/CUDNN backend is allowed to use TF32 as the internal computation data type. When the precision is medium, the MKLDNN backend is allowed to use BF16 as the internal computation data type.
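A sketch of the intended usage, assuming the proposed names land exactly as written above (these functions do not exist in PyTorch yet):

```python
import torch

# Proposed (not yet available): per-op-class precision knobs for conv and RNN,
# mirroring torch.set_float32_matmul_precision / torch.get_float32_matmul_precision.
torch.set_float32_conv_precision("medium")   # allow BF16 internal compute for conv
torch.set_float32_rnn_precision("medium")    # allow BF16 internal compute for RNN

assert torch.get_float32_conv_precision() == "medium"
assert torch.get_float32_rnn_precision() == "medium"
```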

Backend changes:

  • For matmul: currently, we only dispatch at::matmul to mkldnn_matmul when the input tensors are BFloat16. We propose to further dispatch at::matmul to mkldnn_matmul when:
    • (1) float32_matmul_precision is medium, and
    • (2) the input tensors are float32.

In that case, we will use BF16 as the internal computation data type; a PR has already been created. The dispatch condition is sketched below.
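A Python paraphrase of the proposed dispatch condition (the actual check lives in ATen C++; the helper name below is only for illustration):

```python
import torch

def should_dispatch_to_mkldnn_matmul(a: torch.Tensor, b: torch.Tensor) -> bool:
    bf16_inputs = a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16
    fp32_inputs = a.dtype == torch.float32 and b.dtype == torch.float32
    medium = torch.get_float32_matmul_precision() == "medium"
    # Today: only BF16 inputs are routed to mkldnn_matmul.
    # Proposed: also route FP32 inputs when the precision is "medium",
    # and use BF16 as the internal computation data type.
    return bf16_inputs or (fp32_inputs and medium)
```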

  • For Conv: we will check float32_conv_precision in mkldnn_conv and use BF16 as the internal computation data type when the precision is medium.

  • For RNN: we will check float32_rnn_precision in mkldnn_rnn_layer and use BF16 as the internal computation data type when the precision is medium.

Inductor changes:

  • We will pack addmm/mm to mkldnn_linear when float32_matmul_precision is medium (the target scenario is sketched below).
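The scenario this rule targets, sketched with today's public APIs (the mkldnn_linear packing itself is the proposed Inductor change, so the last comment describes intended rather than current behavior):

```python
import torch

torch.set_float32_matmul_precision("medium")

model = torch.nn.Linear(256, 256)       # FP32 weights and activations
compiled = torch.compile(model)         # Inductor backend
out = compiled(torch.randn(32, 256))    # addmm/mm -> (proposed) packed to mkldnn_linear
```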

Motivation

The new BF16 TMUL instructions on Intel Xeon server products can improve user application performance. With these frontend APIs, users can control the internal computation data types for GEMM/CONV/RNN even when the model's data type is Float32. This will:

  • Provide higher precision than the Autocast feature, since only GEMM/CONV/RNN use BF16 internal computation data types, whereas Autocast may compute more ops in BF16.
  • Let users enable BF16 without having to find a place to enable autocast in their model scripts.

Pitch

Provide float32_conv_precision and float32_rnn_precision, and enable the bfloat16 data type for internal computations in the MKLDNN backend when the precision is set to medium.

Additional context

Design option

Frontend API:

  • option 1: provide a backend-agnostic API, get/set_float32_conv/rnn_precision, like float32_matmul_precision.
    • Pros:
      • The user-facing API is unified. Users can use lower-precision computation data types without knowing backend details.
    • Cons:
      • Less fine-grained control over different backends.
  • option 2: provide allow_bf32 in the mkldnn backend, like allow_tf32 in the cudnn backend (both shapes are sketched after this list).
    • Pros:
      • Fine-grained control: users can run BF16 as the internal computation data type on CPU while keeping FP32 on GPU if the model is distributed across multiple kinds of devices.
    • Cons:
      • Users need to learn backend-specific details, and more code changes are required in their applications.
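A sketch contrasting the two API shapes. The TF32 flags shown for CUDA/cuDNN exist today; the mkldnn allow_bf32 flag is hypothetical and is shown only to illustrate what option 2 could look like:

```python
import torch

# Option 1: backend-agnostic knobs, mirroring the existing matmul knob.
torch.set_float32_matmul_precision("medium")
# torch.set_float32_conv_precision("medium")   # proposed in this RFC
# torch.set_float32_rnn_precision("medium")    # proposed in this RFC

# Option 2: backend-specific flags, mirroring the existing TF32 switches.
torch.backends.cuda.matmul.allow_tf32 = True   # exists today
torch.backends.cudnn.allow_tf32 = True         # exists today
# torch.backends.mkldnn.allow_bf32 = True      # hypothetical, per option 2
```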

Design option

Inductor linear packing rules (a sketch of both options follows the list):

  • option 1: Only pack it to mkldnn_linear when the precision is medium.
    • Pros:
      • No performance changes for the pure FP32 case. No regression risks.
    • Cons:
      • Fewer fusion opportunities.
  • option 2: Always pack it to mkldnn_linear.
    • Pros:
      • mkldnn_linear will introduce more fusion opportunities.
    • Cons:
      • May have regression risks for the pure FP32 case.
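An illustrative-only predicate for the trade-off between the two options (the real decision lives in Inductor's mkldnn fusion passes; the function name below is a placeholder):

```python
import torch

def should_pack_to_mkldnn_linear(input_dtype: torch.dtype, option: int) -> bool:
    if input_dtype == torch.bfloat16:
        return True                # BF16 inputs are packed under both options
    if option == 1:
        # Option 1: pack FP32 linear only when lower-precision compute is allowed,
        # so the pure FP32 path is untouched (no regression risk).
        return torch.get_float32_matmul_precision() == "medium"
    # Option 2: always pack, trading possible FP32 regressions for more fusions.
    return True
```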