refine fp32 precision api #125888
@@ -19,9 +19,69 @@
#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <cpuinfo.h>
#endif

namespace at {

namespace {

/*
  These const variables define the fp32 precisions for the different backends.
  We currently have the "generic", "cuda", and "mkldnn" backends, and the fp32
  precision can be chosen from "ieee", "tf32", "bf16", and "none". "ieee" means
  the IEEE standard floating-point format; "tf32" and "bf16" mean that we are
  allowed to use TF32 or BF16 as the internal computation data type for fp32
  computations; "none" means the setting can be overridden by the parent node.

  generic->mkldnn->matmul
                 ->conv
                 ->rnn
         ->cuda ->matmul
                 ->conv
                 ->rnn
*/
const std::map<std::string, std::vector<std::string>> _fp32_precisions = {
    {"generic", {{"ieee", "tf32", "bf16", "none"}}},
    {"mkldnn", {{"ieee", "bf16", "none"}}},
    {"cuda", {{"ieee", "tf32", "none"}}}};
// Check whether the backend and op are legal
void check_fp32_prec_backend_and_op(
    const std::string& backend,
    const std::string& op) {
  static std::vector<std::string> backends = {"generic", "mkldnn", "cuda"};
  static std::vector<std::string> operators = {"conv", "matmul", "rnn", "all"};
  TORCH_CHECK(
      std::find(backends.begin(), backends.end(), backend) != backends.end(),
      "Invalid backend: ",
      backend);
  TORCH_CHECK(
      std::find(operators.begin(), operators.end(), op) != operators.end(),
      "Invalid operator: ",
      op);
  if (backend == "generic") {
    TORCH_CHECK(op == "all", "Invalid operation for generic backend: ", op);
  }
}

// Return whether the precision is supported by the backend
bool validate_fp32_prec(
    const std::string& backend,
    const std::string& precision) {
  auto iterp = _fp32_precisions.find(backend);
  TORCH_CHECK(iterp != _fp32_precisions.end());
  auto precisions = iterp->second;
  bool valid = std::find(precisions.begin(), precisions.end(), precision) !=
      precisions.end();
  return valid;
}

C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api() {
  TORCH_WARN_ONCE(
      "This API is going to be deprecated, please see "
      "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
}
} // namespace

Context::Context() = default;

// TODO: This could be bad juju if someone calls globalContext() in the
@@ -115,12 +175,29 @@ void Context::setUserEnabledNNPACK(bool e) {
  enabled_nnpack = e;
}

bool Context::allowTF32CuDNN() const {
bool Context::allowTF32CuDNN(const std::string& op) const {
  if (op.size() == 0) {
    bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
    bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
    TORCH_CHECK(
        allow_tf32_rnn == allow_tf32_conv && allow_tf32_rnn == allow_tf32_cudnn,
        "PyTorch is checking whether allow_tf32 is enabled for cuDNN without a specific operator name, ",
        "but the current flag(s) indicate that cuDNN conv and cuDNN RNN have different TF32 flags. ",
        "This combination indicates that you have used a mix of the legacy and new APIs to set the TF32 flags. ",
        "We suggest only using the new API to set the TF32 flag(s). See also: ",
        "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
  } else {
    return float32Precision("cuda", op) == "tf32";
  }
  warn_deprecated_fp32_precision_api();
  return allow_tf32_cudnn;
}

void Context::setAllowTF32CuDNN(bool b) {
  setFloat32Precision("cuda", "rnn", b ? "tf32" : "none");
  setFloat32Precision("cuda", "conv", b ? "tf32" : "none");
  allow_tf32_cudnn = b;
  warn_deprecated_fp32_precision_api();
}

Reviewer: Given that cuda has multiple backends (cuDNN, cuBLAS, more?), do we want to allow nesting these as well in the future? The current design shouldn't be blocking us from doing that, right?

Author: The current design contains three layers: backend, operator, and precision. If we want to allow nesting backends for cuda, the current design will not block that. Right now the backend is a string; we could construct a class for the backend to nest it.
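To show how the legacy switch and the new per-op API interact, here is a rough sketch (illustrative only; it assumes at::globalContext() and the Context methods added in this diff):

#include <ATen/Context.h>

// Sketch: the legacy cuDNN switch fans out to the per-op cuda settings.
void cudnn_tf32_sketch() {
  auto& ctx = at::globalContext();
  ctx.setAllowTF32CuDNN(true);                     // legacy API: cuda/conv and cuda/rnn become "tf32"
  ctx.allowTF32CuDNN("conv");                      // true, answered via the new per-op query
  ctx.setFloat32Precision("cuda", "rnn", "ieee");  // new API only: conv and rnn now disagree
  // Calling the op-less legacy query ctx.allowTF32CuDNN() at this point would
  // trip the TORCH_CHECK above, because the flags were set with a mix of APIs.
}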
void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {

@@ -259,7 +336,16 @@ bool Context::allowTF32CuBLAS() const {
    return false;
  }
#endif
  return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
  bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
  bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
  TORCH_CHECK(
      legacy_allow_tf32 == allow_tf32_new,
      "PyTorch is checking whether allow_tf32 is enabled for cuBLAS matmul. ",
      "The current status indicates that you have used a mix of the legacy and new APIs to set the TF32 status for cuBLAS matmul. ",
      "We suggest only using the new API to set the TF32 flag. See also: ",
      "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
  warn_deprecated_fp32_precision_api();
  return allow_tf32_new;
}

void Context::setAllowTF32CuBLAS(bool b) {

@@ -272,27 +358,54 @@ void Context::setAllowTF32CuBLAS(bool b) {
  }
#endif
  float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
  setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
}
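Similarly, a brief sketch of how the legacy cuBLAS flag and the new cuda/matmul setting stay in sync (illustrative only):

#include <ATen/Context.h>

// Sketch: the legacy cuBLAS switch also updates the new cuda/matmul entry.
void cublas_tf32_sketch() {
  auto& ctx = at::globalContext();
  ctx.setAllowTF32CuBLAS(true);                       // sets the legacy enum to "high" and cuda/matmul to "tf32"
  ctx.allowTF32CuBLAS();                              // true; legacy and new flags agree
  ctx.setFloat32Precision("cuda", "matmul", "ieee");  // new API only
  // The legacy enum still reads "high", so the op-less allowTF32CuBLAS() query
  // would now fail its consistency TORCH_CHECK until the two flags agree again.
}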
Float32MatmulPrecision Context::float32MatmulPrecision() const {
  bool invalid = float32Precision("cuda", "matmul") == "tf32" &&
      float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST;
  invalid = invalid ||
      (float32Precision("mkldnn", "matmul") == "bf16" &&
       float32_matmul_precision != at::Float32MatmulPrecision::MEDIUM);
  TORCH_CHECK(
      !invalid,
      "PyTorch is checking the matmul precision without a specific backend name. ",
      "The current status indicates that you have used a mix of the legacy and new APIs to set the matmul precision. ",
      "We suggest only using the new API for matmul precision. See also: ",
      "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
  warn_deprecated_fp32_precision_api();
  return float32_matmul_precision;
}

Reviewer: Same as above, is there a reason we can't just convert and only maintain one set of flags?

Author: Same as above.

void Context::setFloat32MatmulPrecision(Float32MatmulPrecision p) {
  float32_matmul_precision = p;
std::string Context::float32Precision(const std::string& backend, const std::string& op) const {
  check_fp32_prec_backend_and_op(backend, op);
  auto precision = fp32_precision.find(backend)->second.find(op)->second;
  if (precision == "none")
    precision = fp32_precision.find(backend)->second.find("all")->second;
  if (precision == "none")
    precision = fp32_precision.find("generic")->second.find("all")->second;
  bool valid_prec = validate_fp32_prec(backend, precision);
  return valid_prec ? precision : "none";
}
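The getter resolves the precision in the order op -> backend-wide "all" -> generic "all", and falls back to "none" when the inherited value is not supported by the queried backend. A rough sketch of that fallback (illustrative, using the default table from this diff):

#include <ATen/Context.h>

// Sketch: fallback to the generic setting, subject to per-backend validation.
void fallback_sketch() {
  auto& ctx = at::globalContext();
  ctx.setFloat32Precision("generic", "all", "tf32");
  ctx.float32Precision("cuda", "conv");    // "tf32": the op-level default already says tf32
  ctx.float32Precision("mkldnn", "conv");  // "none": inherits "tf32" from generic, but mkldnn
                                           // does not support tf32, so the getter reports "none"
}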
void Context::setFloat32MatmulPrecision(const std::string &s) {
  auto match = [this](const std::string & s_) {
    warn_deprecated_fp32_precision_api();
    // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
    if (s_ == "highest") {
      float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
      setFloat32Precision("cuda", "matmul", "ieee");
      setFloat32Precision("mkldnn", "matmul", "ieee");
      return true;
    } else if (s_ == "high") {
      float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
      setFloat32Precision("cuda", "matmul", "tf32");
      setFloat32Precision("mkldnn", "matmul", "ieee");
      return true;
    } else if (s_ == "medium") {
      float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM;
      setFloat32Precision("cuda", "matmul", "tf32");
      setFloat32Precision("mkldnn", "matmul", "bf16");
      return true;
    }
    return false;

@@ -306,6 +419,27 @@ void Context::setFloat32MatmulPrecision(const std::string &s) {
      "setFloat32MatmulPrecision call has no effect.");
}
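For reference, a short sketch of how the legacy precision strings map onto the new per-backend matmul settings (illustrative; it simply mirrors the branches above):

#include <ATen/Context.h>

// Sketch: "highest"/"high"/"medium" drive both the legacy enum and the new table.
void matmul_precision_sketch() {
  auto& ctx = at::globalContext();
  ctx.setFloat32MatmulPrecision("medium");
  ctx.float32Precision("cuda", "matmul");    // "tf32"
  ctx.float32Precision("mkldnn", "matmul");  // "bf16"
  ctx.setFloat32MatmulPrecision("highest");
  ctx.float32Precision("cuda", "matmul");    // "ieee"
  ctx.float32Precision("mkldnn", "matmul");  // "ieee"
}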
void Context::setFloat32Precision(const std::string& backend, const std::string& op, const std::string& p) {
  check_fp32_prec_backend_and_op(backend, op);
  if (validate_fp32_prec(backend, p)) {
    fp32_precision[backend][op] = p;
  } else {
    std::string msg;
    auto iterp = _fp32_precisions.find(backend);
    TORCH_CHECK(iterp != _fp32_precisions.end());
    for (const auto& prec : iterp->second) {
      msg += prec;
      msg += " ";
    }
    TORCH_WARN(
        "You have set an invalid precision for backend: ",
        backend,
        ", so this setFloat32Precision call has no effect. ",
        "Please choose a precision from: ",
        msg);
  }
}
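A short usage sketch of the setter, including the invalid-precision path (illustrative only):

#include <ATen/Context.h>

// Sketch: valid settings are stored; invalid ones warn and are ignored.
void set_precision_sketch() {
  auto& ctx = at::globalContext();
  ctx.setFloat32Precision("mkldnn", "conv", "bf16");  // accepted: bf16 is valid for mkldnn
  ctx.float32Precision("mkldnn", "conv");             // "bf16"
  ctx.setFloat32Precision("cuda", "conv", "bf16");    // rejected: cuda only allows ieee/tf32/none,
                                                      // so this warns and has no effect
  ctx.float32Precision("cuda", "conv");               // still "tf32" (the default in this diff)
}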
at::LinalgBackend Context::linalgPreferredBackend() const {
  return linalg_preferred_backend;
}
@@ -28,6 +28,7 @@
#include <c10/util/irange.h>

#include <cstdint>
#include <map>
#include <mutex>

namespace at {
@@ -336,14 +337,20 @@ class TORCH_API Context {
  void alertCuBLASConfigNotDeterministic() const;

  void setFloat32MatmulPrecision(const std::string& s);
  bool allowTF32CuDNN() const;
  void setFloat32Precision(
      const std::string& backend,
      const std::string& op,
      const std::string& s);
  bool allowTF32CuDNN(const std::string& op = std::string()) const;
  void setAllowTF32CuDNN(bool);
  bool allowTF32OneDNN() const;
  void setAllowTF32OneDNN(bool);
  bool allowTF32CuBLAS() const;
  void setAllowTF32CuBLAS(bool);
  Float32MatmulPrecision float32MatmulPrecision() const;
  void setFloat32MatmulPrecision(Float32MatmulPrecision p);
  std::string float32Precision(
      const std::string& backend,
      const std::string& op) const;
  bool allowFP16ReductionCuBLAS() const;
  void setAllowFP16ReductionCuBLAS(bool);
  bool allowBF16ReductionCuBLAS() const;
@@ -469,6 +476,23 @@ class TORCH_API Context {
  bool enable_sparse_tensor_invariant_checks = false;
  bool allow_fp16_reduction_cpu = false;

  std::map<std::string, std::map<std::string, std::string>> fp32_precision = {
      {"generic", {{"all", "none"}}},
      {"mkldnn",
       {{"matmul", "none"},
        {"conv", "none"},
        {"rnn", "none"},
        {"all", "none"}}},
      {"cuda",
       {{"matmul",
         float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
             ? "none"
             : "tf32"},
        {"conv", "tf32"},
        {"rnn", "tf32"},
        {"all", "none"}}},
  };

  Allocator* prev_allocator_ptr_{nullptr};
};

Reviewer: Are we guaranteed that float32_matmul_precision defined above will be initialized by now?

Reviewer: So inline class attribute initialization is ordered?
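On the initialization-order question: in C++, non-static data members are initialized in declaration order, and a default member initializer may read members declared earlier, so float32_matmul_precision (declared above in the class, as the diff context suggests) is already set when the fp32_precision initializer runs. A minimal standalone sketch of that guarantee (hypothetical names, not from this PR):

#include <cassert>

// Default member initializers run in declaration order, so table_ can safely read mode_.
struct OrderDemo {
  int mode_ = 1;                        // initialized first (declared first)
  int table_ = (mode_ == 1) ? 10 : 20;  // initialized second; observes mode_ == 1
};

int main() {
  OrderDemo d;
  assert(d.table_ == 10);
  return 0;
}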
Reviewer: Is this a long-term constraint, or just to keep this PR smaller so we can add it in a follow-up?

Author: In our earlier design we did not plan to support setting the precision for individual operators in the "generic" backend. Do you think we need to support this?

Reviewer: Sounds OK to start with, but I'm not sure why we would limit ourselves to that.