refine fp32 precision api #125888

Closed
wants to merge 83 commits into from
Commits
4a54d4b
refine fp32 precision api
zhuhaozhe May 10, 2024
2ce658b
Update on "[WIP] refine fp32 precision api"
zhuhaozhe May 10, 2024
f4c899c
Update on "[WIP] refine fp32 precision api"
zhuhaozhe May 10, 2024
881ff0d
Update on "[WIP] refine fp32 precision api"
zhuhaozhe May 10, 2024
fd44ff4
Update on "[WIP] refine fp32 precision api"
zhuhaozhe May 11, 2024
834055f
Update on "[WIP] refine fp32 precision api"
zhuhaozhe May 13, 2024
ac00878
Update on "[WIP] refine fp32 precision api"
zhuhaozhe May 13, 2024
172a1ae
Update on "refine fp32 precision api"
zhuhaozhe May 16, 2024
aa97dab
Update on "refine fp32 precision api"
zhuhaozhe May 16, 2024
e702fd4
Update on "refine fp32 precision api"
zhuhaozhe May 16, 2024
4ab1d9d
Update on "refine fp32 precision api"
zhuhaozhe May 16, 2024
830e24f
Update on "refine fp32 precision api"
zhuhaozhe May 16, 2024
b5f57ca
Update on "refine fp32 precision api"
zhuhaozhe May 16, 2024
599ff41
Update on "refine fp32 precision api"
zhuhaozhe May 17, 2024
bce259a
Update on "refine fp32 precision api"
zhuhaozhe May 17, 2024
60fa52f
Update on "refine fp32 precision api"
zhuhaozhe May 20, 2024
2d250ab
Update on "refine fp32 precision api"
zhuhaozhe May 20, 2024
ca297b9
Update on "refine fp32 precision api"
zhuhaozhe May 21, 2024
c21c334
Update
zhuhaozhe May 22, 2024
726c734
Update
zhuhaozhe May 23, 2024
371228a
Update
zhuhaozhe May 26, 2024
11073ff
Update
zhuhaozhe May 27, 2024
ef10a6e
Update
zhuhaozhe May 27, 2024
8b0e737
Update
zhuhaozhe May 28, 2024
5ed7300
Update
zhuhaozhe May 28, 2024
1ab9d9f
Update
zhuhaozhe May 30, 2024
f8cf583
Update
zhuhaozhe Jun 4, 2024
c161fa5
Update
zhuhaozhe Jun 4, 2024
36c1246
Update
zhuhaozhe Jun 5, 2024
50d91f7
Update
zhuhaozhe Aug 1, 2024
6481d31
Update
zhuhaozhe Aug 1, 2024
bcc86d7
Update
zhuhaozhe Aug 6, 2024
d4f5927
Update
zhuhaozhe Aug 6, 2024
836575c
Update
zhuhaozhe Aug 29, 2024
a994500
Update
zhuhaozhe Sep 2, 2024
bc6bbbd
Update
zhuhaozhe Sep 4, 2024
e76e4e5
Update
zhuhaozhe Sep 5, 2024
00fc4c6
Update
zhuhaozhe Sep 9, 2024
a7e9fc5
Update
zhuhaozhe Sep 9, 2024
2114920
Update
zhuhaozhe Sep 9, 2024
0a3197d
Update
zhuhaozhe Sep 10, 2024
c490c10
Update
zhuhaozhe Sep 10, 2024
1d837f1
Update
zhuhaozhe Sep 11, 2024
296dc15
Update
zhuhaozhe Sep 12, 2024
b885c66
Update
zhuhaozhe Sep 19, 2024
aa82134
Update
zhuhaozhe Sep 20, 2024
15b96ee
Update
zhuhaozhe Sep 20, 2024
fc042ee
Update
zhuhaozhe Sep 23, 2024
0698768
Update
zhuhaozhe Sep 23, 2024
b5bf46c
Update
zhuhaozhe Sep 23, 2024
69cef61
Update
zhuhaozhe Sep 24, 2024
8b2c06d
Update
zhuhaozhe Sep 25, 2024
bc3382f
Update
zhuhaozhe Nov 1, 2024
198d3c7
Update
zhuhaozhe Nov 1, 2024
2196da3
Update on "refine fp32 precision api"
zhuhaozhe Nov 29, 2024
389ddc3
Update
yanbing-j Dec 5, 2024
f2dd378
Update
yanbing-j Dec 23, 2024
b7881a1
Update
yanbing-j Dec 24, 2024
11d424e
Update
yanbing-j Dec 26, 2024
5a09f53
Update
yanbing-j Jan 3, 2025
6d03259
Update
yanbing-j Jan 20, 2025
b78e817
Update
yanbing-j Feb 6, 2025
6b0cff9
Update
yanbing-j Feb 8, 2025
b9bb74b
Update
yanbing-j Feb 8, 2025
015af59
Update
yanbing-j Mar 7, 2025
ab7da48
Update
yanbing-j Mar 10, 2025
7836e47
Update
yanbing-j Mar 13, 2025
fafaf1a
Update
yanbing-j Mar 28, 2025
db5c3b8
Update
yanbing-j Apr 30, 2025
7e9443d
Update
yanbing-j May 7, 2025
42707e8
Update
yanbing-j May 7, 2025
87e3228
Update
yanbing-j May 9, 2025
4130003
Update
yanbing-j May 10, 2025
73d8d22
Update
yanbing-j May 14, 2025
7d64a33
Update
yanbing-j May 29, 2025
2bb8483
Update
yanbing-j Jun 13, 2025
a33cc14
Update
yanbing-j Jun 19, 2025
af62606
Update
yanbing-j Jun 20, 2025
cfdc4fa
Update
yanbing-j Jun 21, 2025
ee6c03b
Update
yanbing-j Jun 22, 2025
ca7e0ba
Update
yanbing-j Jun 23, 2025
9198c8d
Update
yanbing-j Jun 25, 2025
467fe0e
Update
yanbing-j Jun 26, 2025
Files changed
144 changes: 139 additions & 5 deletions aten/src/ATen/Context.cpp
@@ -19,9 +19,69 @@
#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <cpuinfo.h>
#endif

namespace at {

namespace {

/*
These const variables define the fp32 precisions for the different backends.
We currently have the "generic", "cuda", and "mkldnn" backends, and the fp32
precision can be chosen from "ieee", "tf32", "bf16", and "none". "ieee" means
the IEEE-standard floating point format; "tf32" and "bf16" mean we are allowed
to use TF32 or BF16 as the internal computation data type for fp32
computations; "none" means the setting is overridable by the parent node.

generic -> mkldnn -> matmul
                  -> conv
                  -> rnn
        -> cuda   -> matmul
                  -> conv
                  -> rnn
*/
const std::map<std::string, std::vector<std::string>> _fp32_precisions = {
{"generic", {{"ieee", "tf32", "bf16", "none"}}},
{"mkldnn", {{"ieee", "bf16", "none"}}},
{"cuda", {{"ieee", "tf32", "none"}}}};

// Check whether the backend and op are legal
void check_fp32_prec_backend_and_op(
const std::string& backend,
const std::string& op) {
static std::vector<std::string> backends = {"generic", "mkldnn", "cuda"};
static std::vector<std::string> operators = {"conv", "matmul", "rnn", "all"};
TORCH_CHECK(
std::find(backends.begin(), backends.end(), backend) != backends.end(),
"Invalid backend: ",
backend);
TORCH_CHECK(
std::find(operators.begin(), operators.end(), op) != operators.end(),
"Invalid operator: ",
op);
if (backend == "generic") {
TORCH_CHECK(op == "all", "Invalid operation for generic backend: ", op);
Collaborator: Is this a long-term constraint, or just to keep this PR smaller so it can be added in a follow-up?

Collaborator: In our earlier design we did not plan to set precisions for operators under the "generic" backend. Do you think we need to support this?

Collaborator: Sounds OK to start with, but I'm not sure why we would limit ourselves to that.

}
}
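A short sketch of what this up-front validation means for callers (hypothetical values; the check fires through the public float32Precision/setFloat32Precision entry points, and the wrapper function name is illustrative only):

#include <ATen/Context.h>

void fp32_backend_op_validation_sketch() {
  at::globalContext().float32Precision("cuda", "matmul");      // OK
  // at::globalContext().float32Precision("cudnn", "matmul");  // throws: Invalid backend: cudnn
  // at::globalContext().float32Precision("generic", "conv");  // throws: Invalid operation for generic backend: conv
}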

// Return whether the precision is supported by the given backend
bool validate_fp32_prec(
const std::string& backend,
const std::string& precision) {
auto iterp = _fp32_precisions.find(backend);
TORCH_CHECK(iterp != _fp32_precisions.end());
auto precisions = iterp->second;
bool valid = std::find(precisions.begin(), precisions.end(), precision) !=
precisions.end();
return valid;
}

C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
TORCH_WARN_ONCE(
"This API is going to be deprecated, please see "
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
);
}
} // namespace

Context::Context() = default;

// TODO: This could be bad juju if someone calls globalContext() in the
@@ -115,12 +175,29 @@ void Context::setUserEnabledNNPACK(bool e) {
enabled_nnpack = e;
}

bool Context::allowTF32CuDNN() const {
bool Context::allowTF32CuDNN(const std::string& op) const {
if (op.size() == 0){
bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
TORCH_CHECK(
allow_tf32_rnn == allow_tf32_conv && allow_tf32_rnn == allow_tf32_cudnn,
"PyTorch is checking whether allow_tf32 is enabled for cuDNN without a specific operator name,",
"but the current flag(s) indicate that cuDNN conv and cuDNN RNN have different TF32 flags.",
"This combination indicates that you have used a mix of the legacy and new APIs to set the TF32 flags. ",
"We suggest only using the new API to set the TF32 flag(s). See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
} else {
return float32Precision("cuda", op) == "tf32";
}
warn_deprecated_fp32_precision_api();
return allow_tf32_cudnn;
}

void Context::setAllowTF32CuDNN(bool b) {
setFloat32Precision("cuda", "rnn", b ? "tf32" : "none");
Collaborator: Given that CUDA has multiple backends (cuDNN, cuBLAS, possibly more), do we want to allow nesting these as well in the future? The current design shouldn't block us from doing that, right?

Collaborator: The current design has three layers: backend, operator, and precision. If we want to allow nested backends for CUDA, the current design will not block that. Right now the backend is a string; we could introduce a backend class to support nesting.

setFloat32Precision("cuda", "conv", b ? "tf32" : "none");
allow_tf32_cudnn = b;
warn_deprecated_fp32_precision_api();
}
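A sketch of how the legacy boolean setter now fans out to the per-op settings, based only on the code above (illustrative wrapper function, not part of the diff):

#include <ATen/Context.h>

void cudnn_tf32_legacy_vs_new_sketch() {
  at::globalContext().setAllowTF32CuDNN(true);  // sets cuda/rnn and cuda/conv to "tf32"
  bool conv_tf32 = at::globalContext().allowTF32CuDNN("conv");  // true
  bool rnn_tf32  = at::globalContext().allowTF32CuDNN("rnn");   // true
  (void)conv_tf32; (void)rnn_tf32;
}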

void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
@@ -259,7 +336,16 @@ bool Context::allowTF32CuBLAS() const {
return false;
}
#endif
return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
TORCH_CHECK(
legacy_allow_tf32 == allow_tf32_new,
"PyTorch is checking whether allow_tf32_new is enabled for cuBlas matmul,",
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
"We suggest only using the new API to set the TF32 flag. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return allow_tf32_new;
}
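A sketch of the situation this consistency check guards against: mixing the legacy and new APIs so the two flags disagree (illustrative wrapper, not part of the diff; assumes a build where the early ROCm return above is not taken):

#include <ATen/Context.h>

void cublas_tf32_mixed_api_sketch() {
  at::globalContext().setAllowTF32CuBLAS(false);                      // legacy flag: no TF32
  at::globalContext().setFloat32Precision("cuda", "matmul", "tf32");  // new API: TF32
  // at::globalContext().allowTF32CuBLAS();  // would now fail the TORCH_CHECK above
}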

void Context::setAllowTF32CuBLAS(bool b) {
@@ -272,27 +358,54 @@ void Context::setAllowTF32CuBLAS(bool b) {
}
#endif
float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
}

Float32MatmulPrecision Context::float32MatmulPrecision() const {
bool invalid = float32Precision("cuda", "matmul") == "tf32" &&
float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST;
invalid = invalid ||
(float32Precision("mkldnn", "matmul") == "bf16" &&
float32_matmul_precision != at::Float32MatmulPrecision::MEDIUM);
TORCH_CHECK(
!invalid,
"PyTorch is checking the matmul precision without a specific backend name,",
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
"We suggest only using the new API for matmul precision. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return float32_matmul_precision;

Collaborator: Same as above: is there a reason we can't just convert and only maintain one set of flags?

Collaborator (Author): Same as above.

}

void Context::setFloat32MatmulPrecision(Float32MatmulPrecision p) {
float32_matmul_precision = p;
std::string Context::float32Precision(const std::string& backend, const std::string& op) const {
check_fp32_prec_backend_and_op(backend, op);
auto precision = fp32_precision.find(backend)->second.find(op)->second;
if (precision == "none")
precision = fp32_precision.find(backend)->second.find("all")->second;
if (precision == "none")
precision = fp32_precision.find("generic")->second.find("all")->second;
bool valid_prec = validate_fp32_prec(backend, precision);
return valid_prec ? precision : "none";
}
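A worked example of the resolution order implemented above (op setting, then the backend-level "all" setting, then generic "all", then validation against the backend's supported precisions). Illustrative wrapper, not part of the diff; assumes default settings:

#include <ATen/Context.h>

void fp32_precision_resolution_sketch() {
  // "tf32" is a valid generic-level setting but is not supported by mkldnn,
  // so the final validation step falls back to "none" for mkldnn ops.
  at::globalContext().setFloat32Precision("generic", "all", "tf32");
  std::string prec = at::globalContext().float32Precision("mkldnn", "matmul"); // "none"
  (void)prec;
}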

void Context::setFloat32MatmulPrecision(const std::string &s) {
auto match = [this](const std::string & s_) {
warn_deprecated_fp32_precision_api();
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
if (s_ == "highest") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
setFloat32Precision("cuda", "matmul", "ieee");
setFloat32Precision("mkldnn", "matmul", "ieee");
return true;
} else if (s_ == "high") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
setFloat32Precision("cuda", "matmul", "tf32");
setFloat32Precision("mkldnn", "matmul", "ieee");
return true;
} else if (s_ == "medium") {
float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM;
setFloat32Precision("cuda", "matmul", "tf32");
setFloat32Precision("mkldnn", "matmul", "bf16");
return true;
}
return false;
@@ -306,6 +419,27 @@ void Context::setFloat32MatmulPrecision(const std::string &s) {
"setFloat32MatmulPrecision call has no effect.");
}
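The legacy-to-new mapping implemented by match() above, summarized as a sketch (illustrative wrapper, not part of the diff):

#include <ATen/Context.h>

void matmul_precision_mapping_sketch() {
  // "highest" -> cuda/matmul = "ieee", mkldnn/matmul = "ieee"
  // "high"    -> cuda/matmul = "tf32", mkldnn/matmul = "ieee"
  // "medium"  -> cuda/matmul = "tf32", mkldnn/matmul = "bf16"
  at::globalContext().setFloat32MatmulPrecision("medium");
  std::string cuda_prec   = at::globalContext().float32Precision("cuda", "matmul");   // "tf32"
  std::string mkldnn_prec = at::globalContext().float32Precision("mkldnn", "matmul"); // "bf16"
  (void)cuda_prec; (void)mkldnn_prec;
}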

void Context::setFloat32Precision(const std::string& backend, const std::string& op, const std::string& p) {
check_fp32_prec_backend_and_op(backend, op);
if (validate_fp32_prec(backend, p)) {
fp32_precision[backend][op] = p;
} else {
std::string msg;
auto iterp = _fp32_precisions.find(backend);
TORCH_CHECK(iterp != _fp32_precisions.end());
for (const auto& prec : iterp->second) {
msg += prec;
msg += " ";
}
TORCH_WARN(
"Invalid precision for backend: ",
backend,
". setFloat32Precision call has no effect. ",
"Please choose a precision from: ",
msg);
}
}
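A short sketch of the validation path above: an unsupported precision for a backend only warns and leaves the previous value untouched (illustrative wrapper, not part of the diff):

#include <ATen/Context.h>

void set_fp32_precision_validation_sketch() {
  at::globalContext().setFloat32Precision("cuda", "conv", "tf32");  // takes effect
  at::globalContext().setFloat32Precision("cuda", "conv", "bf16");  // warns; "bf16" is not valid for cuda
  std::string prec = at::globalContext().float32Precision("cuda", "conv"); // still "tf32"
  (void)prec;
}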

at::LinalgBackend Context::linalgPreferredBackend() const {
return linalg_preferred_backend;
}
28 changes: 26 additions & 2 deletions aten/src/ATen/Context.h
@@ -28,6 +28,7 @@
#include <c10/util/irange.h>

#include <cstdint>
#include <map>
#include <mutex>

namespace at {
@@ -336,14 +337,20 @@ class TORCH_API Context {
void alertCuBLASConfigNotDeterministic() const;

void setFloat32MatmulPrecision(const std::string& s);
bool allowTF32CuDNN() const;
void setFloat32Precision(
const std::string& backend,
const std::string& op,
const std::string& s);
bool allowTF32CuDNN(const std::string& op = std::string()) const;
void setAllowTF32CuDNN(bool);
bool allowTF32OneDNN() const;
void setAllowTF32OneDNN(bool);
bool allowTF32CuBLAS() const;
void setAllowTF32CuBLAS(bool);
Float32MatmulPrecision float32MatmulPrecision() const;
void setFloat32MatmulPrecision(Float32MatmulPrecision p);
std::string float32Precision(
const std::string& backend,
const std::string& op) const;
bool allowFP16ReductionCuBLAS() const;
void setAllowFP16ReductionCuBLAS(bool);
bool allowBF16ReductionCuBLAS() const;
@@ -469,6 +476,23 @@ class TORCH_API Context {
bool enable_sparse_tensor_invariant_checks = false;
bool allow_fp16_reduction_cpu = false;

std::map<std::string, std::map<std::string, std::string>> fp32_precision = {
{"generic", {{"all", "none"}}},
{"mkldnn",
{{"matmul", "none"},
{"conv", "none"},
{"rnn", "none"},
{"all", "none"}}},
{"cuda",
{{"matmul",
float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
Collaborator: Are we guaranteed that float32_matmul_precision, defined above, will be initialized by now?

Collaborator: So inline class attribute initialization is ordered?

? "none"
: "tf32"},
{"conv", "tf32"},
{"rnn", "tf32"},
{"all", "none"}}},
};

Allocator* prev_allocator_ptr_{nullptr};
};

4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -407,7 +407,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
} else if constexpr (std::is_same_v<Dtype, float>) {
if (at::globalContext().allowTF32CuBLAS()) {
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
} else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
@@ -1589,7 +1589,7 @@ bool gemm_and_bias(
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
} else if constexpr (std::is_same_v<Dtype, float>) {
if (at::globalContext().allowTF32CuBLAS()) {
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
3 changes: 2 additions & 1 deletion aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -218,7 +218,8 @@ cublasHandle_t getCurrentCUDABlasHandle() {
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.
// To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
if (!NoTF32Guard::should_disable_tf32() && at::globalContext().allowTF32CuBLAS()) {
if (!NoTF32Guard::should_disable_tf32() &&
at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
} else {
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/tunable/GemmCommon.h
@@ -160,7 +160,7 @@ inline std::string ComputeTypeFor() {
// ROCBLAS and hipBLASLt.
template <>
inline std::string ComputeTypeFor<float>() {
if (!at::globalContext().allowTF32CuBLAS()) {
if (at::globalContext().float32Precision("cuda", "matmul") != "tf32") {
return "f32_r";
} else {
return "xf32_r";
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/tunable/GemmHipblaslt.h
@@ -499,7 +499,7 @@ class HipblasltGemmOp : public Callable<ParamsT> {
}

hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
if (at::globalContext().allowTF32CuBLAS()) {
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
}
HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F);
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/tunable/GemmRocblas.h
@@ -141,7 +141,7 @@ class RocblasGemmOp : public Callable<GemmParams<T>> {

TuningStatus Call(const GemmParams<T>* params) override {
auto input_output_type = RocBlasDataTypeFor<T>();
if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r)
return FAIL; // no support for TF32 in rocBLAS
auto compute_type = RocBlasComputeTypeFor<T>();
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
@@ -209,7 +209,7 @@ class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>

TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
auto input_output_type = RocBlasDataTypeFor<T>();
if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r)
return FAIL; // no support for TF32 in rocBLAS
auto compute_type = RocBlasComputeTypeFor<T>();
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
8 changes: 4 additions & 4 deletions aten/src/ATen/native/Convolution.cpp
@@ -1174,7 +1174,7 @@ at::Tensor convolution(
bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
return at::_convolution(input, weight, bias, stride, padding, dilation,
transposed, output_padding, groups,
ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN());
ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN("conv"));
}

at::Tensor convolution_overrideable(
@@ -1319,7 +1319,7 @@ ConvBackend select_conv_backend(
params.benchmark = ctx.benchmarkCuDNN();
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
params.cudnn_enabled = ctx.userEnabledCuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN("conv");

auto input = input_r;
auto weight = weight_r;
@@ -1705,7 +1705,7 @@ at::Tensor _convolution(
c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
const Tensor& bias_r = *bias_r_maybe_owned;

return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN());
return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN("conv"));
}

std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
@@ -2003,7 +2003,7 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward(
params.benchmark = ctx.benchmarkCuDNN();
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
params.cudnn_enabled = ctx.userEnabledCuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN("conv");

// Validate inputs.
check_shape_backward(input, weight.sizes(), params);
7 changes: 4 additions & 3 deletions aten/src/ATen/native/cudnn/ConvShared.cpp
@@ -169,7 +169,8 @@ std::string repro_from_args(const ConvolutionParams& params) {
ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n";
ss << "import torch\n";
ss << "torch.backends.cuda.matmul.allow_tf32 = "
<< pybool(at::globalContext().allowTF32CuBLAS()) << "\n";
<< pybool(at::globalContext().float32Precision("cuda", "matmul") == "tf32")
<< "\n";
ss << "torch.backends.cudnn.benchmark = "
<< pybool(at::globalContext().benchmarkCuDNN()) << "\n";
ss << "torch.backends.cudnn.deterministic = " << pybool(params.deterministic)
@@ -725,7 +726,7 @@ Tensor cudnn_convolution_relu(

auto& ctx = at::globalContext();
bool benchmark = ctx.benchmarkCuDNN();
bool allow_tf32 = ctx.allowTF32CuDNN();
bool allow_tf32 = ctx.allowTF32CuDNN("conv");
auto _bias = bias_t.has_value()
? bias_t.value()
: at::zeros(
@@ -783,7 +784,7 @@ Tensor cudnn_convolution_add_relu(
}

auto& ctx = at::globalContext();
bool allow_tf32 = ctx.allowTF32CuDNN();
bool allow_tf32 = ctx.allowTF32CuDNN("conv");
bool benchmark = ctx.benchmarkCuDNN();
auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
auto _bias = bias_t.has_value()