Commit 0a96ec6

fix conv v7 workspace size limit error, test=develop (#17902)
1 parent 4d5f693 commit 0a96ec6
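
This commit fixes a workspace-size-limit error in the cuDNN convolution kernels. The cuDNN 7 algorithm-search API (cudnnGetConvolutionForwardAlgorithm_v7 and the matching backward-data and backward-filter calls) cannot be given a workspace limit, so the algorithm it returns may require more workspace than the configured workspace_size_limit. The fix queries the workspace needed by the v7 choice and, if it exceeds the limit, falls back to the pre-v7 heuristics (cudnnGetConvolution*Algorithm with the *_SPECIFY_WORKSPACE_LIMIT preference), which do honor a memory limit.

Below is a condensed sketch of the forward-path selection logic, not the kernel code itself: it assumes cuDNN 7.x (where the legacy heuristic API still exists), omits descriptor setup and PaddlePaddle's CUDNN_ENFORCE wrapper, and uses a hypothetical CheckCudnn helper for error handling.

#include <cudnn.h>

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Hypothetical helper: abort on any cuDNN error (the kernel uses CUDNN_ENFORCE).
static void CheckCudnn(cudnnStatus_t s) {
  if (s != CUDNN_STATUS_SUCCESS) {
    std::fprintf(stderr, "cuDNN error: %s\n", cudnnGetErrorString(s));
    std::abort();
  }
}

// Pick a forward algorithm whose workspace fits within workspace_limit.
cudnnConvolutionFwdAlgo_t ChooseFwdAlgo(cudnnHandle_t handle,
                                        cudnnTensorDescriptor_t x_desc,
                                        cudnnFilterDescriptor_t w_desc,
                                        cudnnConvolutionDescriptor_t conv_desc,
                                        cudnnTensorDescriptor_t y_desc,
                                        size_t workspace_limit,
                                        size_t* workspace_bytes) {
  // 1) Ask the v7 heuristic for its best candidate. It cannot be told about
  //    a workspace limit, so the candidate may need too much memory.
  cudnnConvolutionFwdAlgoPerf_t perf;
  int returned_count = 0;
  CheckCudnn(cudnnGetConvolutionForwardAlgorithm_v7(
      handle, x_desc, w_desc, conv_desc, y_desc, /*requestedAlgoCount=*/1,
      &returned_count, &perf));
  cudnnConvolutionFwdAlgo_t algo = perf.algo;

  // 2) Query how much workspace that candidate actually requires.
  CheckCudnn(cudnnGetConvolutionForwardWorkspaceSize(
      handle, x_desc, w_desc, conv_desc, y_desc, algo, workspace_bytes));

  // 3) If the request exceeds the limit, fall back to the pre-v7 heuristic,
  //    which accepts an explicit limit, and re-query the workspace size.
  if (*workspace_bytes > workspace_limit) {
    CheckCudnn(cudnnGetConvolutionForwardAlgorithm(
        handle, x_desc, w_desc, conv_desc, y_desc,
        CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_limit, &algo));
    CheckCudnn(cudnnGetConvolutionForwardWorkspaceSize(
        handle, x_desc, w_desc, conv_desc, y_desc, algo, workspace_bytes));
  }
  return algo;
}

The actual kernel additionally records whether the workspace size has already been obtained (has_got_workspace_size and its backward counterparts) so it is not queried a second time further down, as the diff below shows.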

File tree

1 file changed: +103 -35 lines changed

paddle/fluid/operators/conv_cudnn_op.cu.cc

Lines changed: 103 additions & 35 deletions
@@ -165,6 +165,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
 
     // TODO(dangqingqing) simplify the following code by SearchAlgorithm in
     // conv_cudnn_helper.h
+    bool has_got_workspace_size = false;
     if ((!exhaustive_search) && (!half_float)) {
 #if CUDNN_VERSION >= 7001
       using perf_t = cudnnConvolutionFwdAlgoPerf_t;
@@ -176,11 +177,29 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
           cudnn_output_desc, kNUM_CUDNN_FWD_ALGS, &perf_count,
           perf_results.get()));
       algo = (perf_results.get())[best_algo_idx].algo;
-#else
-      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
+
+      // get workspace size able to allocate
+      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
           handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
-          cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
-          workspace_size_limit, &algo));
+          cudnn_output_desc, algo, &workspace_size_in_bytes));
+
+      // NOTE(zjl): cudnnGetConvolutionForwardAlgorithm_v7 cannot limit
+      // workspace size. If the workspace size found by v7 exceeds the limit,
+      // we should fall back to the non-v7 method to find another algorithm.
+      if (workspace_size_in_bytes > workspace_size_limit) {
+        VLOG(1) << "Fallback to non-v7 method to find conv algorithm because "
+                   "the workspace size request("
+                << workspace_size_in_bytes << ") exceeds the limit("
+                << workspace_size_limit << ")";
+#endif
+        CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
+            handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
+            cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+            workspace_size_limit, &algo));
+#if CUDNN_VERSION >= 7001
+      } else {
+        has_got_workspace_size = true;
+      }
 #endif
 
     VLOG(3) << "cuDNN forward algo " << algo;
@@ -219,10 +238,13 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
           "cuDNN exhaustive search doesn't support half float.");
     }
 
-    // get workspace size able to allocate
-    CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
-        handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
-        cudnn_output_desc, algo, &workspace_size_in_bytes));
+    if (!has_got_workspace_size) {
+      // get workspace size able to allocate
+      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
+          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
+          cudnn_output_desc, algo, &workspace_size_in_bytes));
+    }
+
     // It is possible for float16 on Volta GPU to allocate more memory than
     // the limit because the algo is overridden to use tensor core.
     PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
@@ -366,6 +388,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
     auto x_dims = framework::vectorize(input->dims());
     auto f_dims = framework::vectorize(filter->dims());
     auto handle = dev_ctx.cudnn_handle();
+
+    bool has_got_bwd_data_ws_size = false;
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
       if (exhaustive_search) {
@@ -431,28 +455,49 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
                   CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
           data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
         }
-#else
+
         CUDNN_ENFORCE(
-            platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
-                handle, cudnn_filter_desc,
-                // dyDesc: Handle to the previously initialized input
-                // differential
-                // tensor descriptor.
-                cudnn_output_grad_desc, cudnn_conv_desc,
-                // dxDesc: Handle to the previously initialized output tensor
-                // descriptor.
-                cudnn_input_desc,
-                CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
-                workspace_size_limit, &data_algo));
+            platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
+                handle, cudnn_filter_desc, cudnn_output_grad_desc,
+                cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
+        auto new_workspace_size = std::max(workspace_size_in_bytes, tmp_size);
+
+        if (new_workspace_size > workspace_size_limit) {
+          VLOG(1) << "Fallback to non-v7 method to find conv algorithm because "
+                     "the workspace size request("
+                  << new_workspace_size << ") exceeds the limit("
+                  << workspace_size_limit << ")";
+#endif
+          CUDNN_ENFORCE(
+              platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
+                  handle, cudnn_filter_desc,
+                  // dyDesc: Handle to the previously initialized input
+                  // differential
+                  // tensor descriptor.
+                  cudnn_output_grad_desc, cudnn_conv_desc,
+                  // dxDesc: Handle to the previously initialized output tensor
+                  // descriptor.
+                  cudnn_input_desc,
+                  CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+                  workspace_size_limit, &data_algo));
+#if CUDNN_VERSION >= 7001
+        } else {
+          workspace_size_in_bytes = new_workspace_size;
+          has_got_bwd_data_ws_size = true;
+        }
 #endif
       }
-      CUDNN_ENFORCE(
-          platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
-              handle, cudnn_filter_desc, cudnn_output_grad_desc,
-              cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
-      workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
+
+      if (!has_got_bwd_data_ws_size) {
+        CUDNN_ENFORCE(
+            platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
+                handle, cudnn_filter_desc, cudnn_output_grad_desc,
+                cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
+        workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
+      }
     }
 
+    bool has_got_bwd_filter_ws_size = false;
     if (filter_grad) {
       T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
       if (exhaustive_search) {
@@ -495,22 +540,45 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
             cudnn_conv_desc, cudnn_filter_desc, kNUM_CUDNN_BWD_FILTER_ALGS,
             &perf_count, perf_results.get()));
         filter_algo = (perf_results.get())[best_algo_idx].algo;
-#else
+
         CUDNN_ENFORCE(
-            platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
+            platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
                 handle, cudnn_input_desc, cudnn_output_grad_desc,
-                cudnn_conv_desc, cudnn_filter_desc,
-                CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
-                workspace_size_limit, &filter_algo));
+                cudnn_conv_desc, cudnn_filter_desc, filter_algo, &tmp_size));
+        auto new_workspace_size = std::max(workspace_size_in_bytes, tmp_size);
+
+        if (new_workspace_size > workspace_size_limit) {
+          VLOG(1) << "Fallback to non-v7 method to find conv algorithm because "
+                     "the workspace size request("
+                  << new_workspace_size << ") exceeds the limit("
+                  << workspace_size_limit << ")";
 #endif
+          CUDNN_ENFORCE(
+              platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
+                  handle, cudnn_input_desc, cudnn_output_grad_desc,
+                  cudnn_conv_desc, cudnn_filter_desc,
+                  CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+                  workspace_size_limit, &filter_algo));
+#if CUDNN_VERSION >= 7001
+        } else {
+          workspace_size_in_bytes = new_workspace_size;
+          has_got_bwd_filter_ws_size = true;
+        }
+#endif
+      }
+
+      if (!has_got_bwd_filter_ws_size) {
+        CUDNN_ENFORCE(
+            platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
+                handle, cudnn_input_desc, cudnn_output_grad_desc,
+                cudnn_conv_desc, cudnn_filter_desc, filter_algo, &tmp_size));
+        workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
       }
-      CUDNN_ENFORCE(
-          platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
-              handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
-              cudnn_filter_desc, filter_algo, &tmp_size));
-      workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
     }
 
+    PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
+                      "workspace_size to be allocated exceeds the limit");
+
     // ------------------- cudnn conv workspace ---------------------
     if (!cudnn_workspace_ptr) {
       cudnn_workspace =
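
In the backward kernel the data-gradient and filter-gradient algorithms share a single workspace buffer, so the kernel keeps the maximum of the two workspace requests rather than their sum, and the PADDLE_ENFORCE_LE added at the end of this diff verifies that the combined requirement still fits under the limit. A minimal, self-contained sketch of that accounting, with an illustrative function name and a plain exception standing in for PADDLE_ENFORCE_LE:

#include <algorithm>
#include <cstddef>
#include <stdexcept>

// Combine the backward-data and backward-filter workspace requests the way the
// kernel does: one shared buffer, so take the max of the two requests, then
// enforce the limit (mirroring the newly added PADDLE_ENFORCE_LE check).
std::size_t CombineBackwardWorkspace(std::size_t bwd_data_ws_bytes,
                                     std::size_t bwd_filter_ws_bytes,
                                     std::size_t workspace_size_limit) {
  std::size_t workspace_size_in_bytes = 0;
  workspace_size_in_bytes = std::max(workspace_size_in_bytes, bwd_data_ws_bytes);
  workspace_size_in_bytes = std::max(workspace_size_in_bytes, bwd_filter_ws_bytes);
  if (workspace_size_in_bytes > workspace_size_limit) {
    throw std::runtime_error("workspace_size to be allocated exceeds the limit");
  }
  return workspace_size_in_bytes;
}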
