@@ -165,6 +165,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
165165
166166 // TODO(dangqingqing) simplify the following code by SearchAlgorithm in
167167 // conv_cudnn_helper.h
168+ bool has_got_workspace_size = false ;
168169 if ((!exhaustive_search) && (!half_float)) {
169170#if CUDNN_VERSION >= 7001
170171 using perf_t = cudnnConvolutionFwdAlgoPerf_t;
@@ -176,11 +177,29 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
176177 cudnn_output_desc, kNUM_CUDNN_FWD_ALGS , &perf_count,
177178 perf_results.get ()));
178179 algo = (perf_results.get ())[best_algo_idx].algo ;
179- #else
180- CUDNN_ENFORCE (platform::dynload::cudnnGetConvolutionForwardAlgorithm (
180+
181+ // get workspace size able to allocate
182+ CUDNN_ENFORCE (platform::dynload::cudnnGetConvolutionForwardWorkspaceSize (
181183 handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
182- cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
183- workspace_size_limit, &algo));
184+ cudnn_output_desc, algo, &workspace_size_in_bytes));
185+
186+ // NOTE(zjl): cudnnGetConvolutionForwardAlgorithm_v7 cannot limit
187+ // workspace size. If the workspace size found by v7 exceeds the limit,
188+ // we should fallback to non-v7 method to find another algorithm.
189+ if (workspace_size_in_bytes > workspace_size_limit) {
190+ VLOG (1 ) << " Fallback to non-v7 method to find conv algorithm because "
191+ " the workspace size request("
192+ << workspace_size_in_bytes << " ) exceeds the limit("
193+ << workspace_size_limit << " )" ;
194+ #endif
195+ CUDNN_ENFORCE (platform::dynload::cudnnGetConvolutionForwardAlgorithm (
196+ handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
197+ cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
198+ workspace_size_limit, &algo));
199+ #if CUDNN_VERSION >= 7001
200+ } else {
201+ has_got_workspace_size = true ;
202+ }
184203#endif
185204
186205 VLOG (3 ) << " cuDNN forward algo " << algo;
@@ -219,10 +238,13 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
219238 " cuDNN exhaustive search doesn't support half float." );
220239 }
221240
222- // get workspace size able to allocate
223- CUDNN_ENFORCE (platform::dynload::cudnnGetConvolutionForwardWorkspaceSize (
224- handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
225- cudnn_output_desc, algo, &workspace_size_in_bytes));
241+ if (!has_got_workspace_size) {
242+ // get workspace size able to allocate
243+ CUDNN_ENFORCE (platform::dynload::cudnnGetConvolutionForwardWorkspaceSize (
244+ handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
245+ cudnn_output_desc, algo, &workspace_size_in_bytes));
246+ }
247+
226248 // It is possible for float16 on Volta GPU to allocate more memory than
227249 // the limit because the algo is overrided to use tensor core.
228250 PADDLE_ENFORCE_LE (workspace_size_in_bytes, workspace_size_limit,
@@ -366,6 +388,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
366388 auto x_dims = framework::vectorize (input->dims ());
367389 auto f_dims = framework::vectorize (filter->dims ());
368390 auto handle = dev_ctx.cudnn_handle ();
391+
392+ bool has_got_bwd_data_ws_size = false ;
369393 if (input_grad) {
370394 T* input_grad_data = input_grad->mutable_data <T>(ctx.GetPlace ());
371395 if (exhaustive_search) {
@@ -431,28 +455,49 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
431455 CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
432456 data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
433457 }
434- # else
458+
435459 CUDNN_ENFORCE (
436- platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm (
437- handle, cudnn_filter_desc,
438- // dyDesc: Handle to the previously initialized input
439- // differential
440- // tensor descriptor.
441- cudnn_output_grad_desc, cudnn_conv_desc,
442- // dxDesc: Handle to the previously initialized output tensor
443- // descriptor.
444- cudnn_input_desc,
445- CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
446- workspace_size_limit, &data_algo));
460+ platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize (
461+ handle, cudnn_filter_desc, cudnn_output_grad_desc,
462+ cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
463+ auto new_workspace_size = std::max (workspace_size_in_bytes, tmp_size);
464+
465+ if (new_workspace_size > workspace_size_limit) {
466+ VLOG (1 ) << " Fallback to non-v7 method to find conv algorithm because "
467+ " the workspace size request("
468+ << new_workspace_size << " ) exceeds the limit("
469+ << workspace_size_limit << " )" ;
470+ #endif
471+ CUDNN_ENFORCE (
472+ platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm (
473+ handle, cudnn_filter_desc,
474+ // dyDesc: Handle to the previously initialized input
475+ // differential
476+ // tensor descriptor.
477+ cudnn_output_grad_desc, cudnn_conv_desc,
478+ // dxDesc: Handle to the previously initialized output tensor
479+ // descriptor.
480+ cudnn_input_desc,
481+ CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
482+ workspace_size_limit, &data_algo));
483+ #if CUDNN_VERSION >= 7001
484+ } else {
485+ workspace_size_in_bytes = new_workspace_size;
486+ has_got_bwd_data_ws_size = true ;
487+ }
447488#endif
448489 }
449- CUDNN_ENFORCE (
450- platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize (
451- handle, cudnn_filter_desc, cudnn_output_grad_desc,
452- cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
453- workspace_size_in_bytes = std::max (workspace_size_in_bytes, tmp_size);
490+
491+ if (!has_got_bwd_data_ws_size) {
492+ CUDNN_ENFORCE (
493+ platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize (
494+ handle, cudnn_filter_desc, cudnn_output_grad_desc,
495+ cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
496+ workspace_size_in_bytes = std::max (workspace_size_in_bytes, tmp_size);
497+ }
454498 }
455499
500+ bool has_got_bwd_filter_ws_size = false ;
456501 if (filter_grad) {
457502 T* filter_grad_data = filter_grad->mutable_data <T>(ctx.GetPlace ());
458503 if (exhaustive_search) {
@@ -495,22 +540,45 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
495540 cudnn_conv_desc, cudnn_filter_desc, kNUM_CUDNN_BWD_FILTER_ALGS ,
496541 &perf_count, perf_results.get ()));
497542 filter_algo = (perf_results.get ())[best_algo_idx].algo ;
498- # else
543+
499544 CUDNN_ENFORCE (
500- platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm (
545+ platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize (
501546 handle, cudnn_input_desc, cudnn_output_grad_desc,
502- cudnn_conv_desc, cudnn_filter_desc,
503- CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
504- workspace_size_limit, &filter_algo));
547+ cudnn_conv_desc, cudnn_filter_desc, filter_algo, &tmp_size));
548+ auto new_workspace_size = std::max (workspace_size_in_bytes, tmp_size);
549+
550+ if (new_workspace_size > workspace_size_limit) {
551+ VLOG (1 ) << " Fallback to non-v7 method to find conv algorithm because "
552+ " the workspace size request("
553+ << new_workspace_size << " ) exceeds the limit("
554+ << workspace_size_limit << " )" ;
505555#endif
556+ CUDNN_ENFORCE (
557+ platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm (
558+ handle, cudnn_input_desc, cudnn_output_grad_desc,
559+ cudnn_conv_desc, cudnn_filter_desc,
560+ CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
561+ workspace_size_limit, &filter_algo));
562+ #if CUDNN_VERSION >= 7001
563+ } else {
564+ workspace_size_in_bytes = new_workspace_size;
565+ has_got_bwd_filter_ws_size = true ;
566+ }
567+ #endif
568+ }
569+
570+ if (!has_got_bwd_filter_ws_size) {
571+ CUDNN_ENFORCE (
572+ platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize (
573+ handle, cudnn_input_desc, cudnn_output_grad_desc,
574+ cudnn_conv_desc, cudnn_filter_desc, filter_algo, &tmp_size));
575+ workspace_size_in_bytes = std::max (workspace_size_in_bytes, tmp_size);
506576 }
507- CUDNN_ENFORCE (
508- platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize (
509- handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
510- cudnn_filter_desc, filter_algo, &tmp_size));
511- workspace_size_in_bytes = std::max (workspace_size_in_bytes, tmp_size);
512577 }
513578
579+ PADDLE_ENFORCE_LE (workspace_size_in_bytes, workspace_size_limit,
580+ " workspace_size to be allocated exceeds the limit" );
581+
514582 // ------------------- cudnn conv workspace ---------------------
515583 if (!cudnn_workspace_ptr) {
516584 cudnn_workspace =
0 commit comments