Skip to content

Commit ea47d21

Browse files
authored
Make FLAGS_determinstic effective in conv2d forward. (#37173)
* Make FLAGS_determinstic effective in conv2d forward. * Add call of SetCinnCudnnDeterministic in cinn_launch op.
1 parent 5091fed commit ea47d21

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

paddle/fluid/operators/cinn_launch_op.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include "paddle/fluid/operators/cinn_launch_op.h"
1616
#include "paddle/fluid/string/string_helper.h"
1717

18+
DECLARE_bool(cudnn_deterministic);
19+
1820
namespace paddle {
1921
namespace operators {
2022

@@ -67,6 +69,12 @@ void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
6769
compiled_obj.runtime_program->Execute(&context.FinalizeArguments());
6870
}
6971

72+
void SetCinnRuntimeFlags() {
73+
VLOG(4) << "Set FLAGS_cinn_cudnn_deterministic to "
74+
<< FLAGS_cudnn_deterministic;
75+
::cinn::runtime::SetCinnCudnnDeterministic(FLAGS_cudnn_deterministic);
76+
}
77+
7078
CinnLaunchContext::CinnLaunchContext(const CinnCompiledObject& compiled_obj)
7179
: paddle2cinn_varmap_(compiled_obj.paddle2cinn_varmap),
7280
cinn_scope_(compiled_obj.scope) {

paddle/fluid/operators/cinn_launch_op.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "cinn/hlir/framework/graph_compiler.h"
2222
#include "cinn/hlir/framework/scope.h"
2323
#include "cinn/runtime/cinn_runtime.h"
24+
#include "cinn/runtime/flags.h"
2425
#include "paddle/fluid/framework/data_type.h"
2526
#include "paddle/fluid/framework/op_registry.h"
2627
#include "paddle/fluid/framework/operator.h"
@@ -110,6 +111,9 @@ void DebugCinnCompiledResult(const CinnCompiledObject& result);
110111
// Launch cinn to execute compiled executable program and wait done
111112
void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
112113
const CinnLaunchContext& context);
114+
115+
// Set cinn FLAGS (such as FLAGS_cinn_cudnn_deterministic) with paddle's FLAGS.
116+
void SetCinnRuntimeFlags();
113117
} // namespace details
114118

115119
template <typename DeviceContext, typename T>
@@ -202,7 +206,10 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
202206
launch_context->AssignInternalVariable(var_name, tensor);
203207
}
204208

205-
// Step 4. Launch CINN to execute the compiled executable program
209+
// Step 4. Set CINN runtime FLAGS, such as FLAGS_cinn_cudnn_deterministic.
210+
details::SetCinnRuntimeFlags();
211+
212+
// Step 5. Launch CINN to execute the compiled executable program
206213
details::LaunchCinnExecution(cinn_compiled_object, *launch_context);
207214
VLOG(4) << "CinnLaunchOp launch execution done.";
208215
}

paddle/fluid/operators/conv_cudnn_op.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,11 +298,12 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
298298
miopenConvFwdAlgorithm_t algo{};
299299
using search = SearchAlgorithm<miopenConvFwdAlgorithm_t>;
300300
workspace_size = search::GetWorkspaceSize(args);
301-
algo = search::Find<T>(args, exhaustive_search, false, workspace_size, ctx);
301+
algo = search::Find<T>(args, exhaustive_search, deterministic,
302+
workspace_size, ctx);
302303
#else
303304
cudnnConvolutionFwdAlgo_t algo{};
304305
using search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
305-
algo = search::Find<T>(args, exhaustive_search, false, ctx);
306+
algo = search::Find<T>(args, exhaustive_search, deterministic, ctx);
306307
workspace_size = search::GetWorkspaceSize(args, algo);
307308
#endif
308309

0 commit comments

Comments
 (0)