Skip to content

Commit

Permalink
deconv cudnn
Browse files Browse the repository at this point in the history
  • Loading branch information
zchen0211 committed Nov 2, 2017
1 parent 7e34b8e commit 2d956b8
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions paddle/operators/conv2d_transpose_cudnn_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout;
using CUDADeviceContext = platform::CUDADeviceContext;

static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024;
static constexpr size_t kConvCudnnWorkspaceLimitBytes = 1024 * 1024 * 1024;

template <typename T>
class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
Expand Down Expand Up @@ -71,7 +71,7 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv workspace ---------------------
void* cudnn_workspace = nullptr;
size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes;
if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024;
}
Expand Down Expand Up @@ -125,6 +125,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {

std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
// cudnn v5 does not support dilations
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");

Expand Down Expand Up @@ -153,7 +154,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
cudnnConvolutionBwdFilterAlgo_t filter_algo;
size_t bwd_filter_ws_size, fwd_ws_size;
size_t workspace_size_in_bytes = 0;
size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes;
if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024;
}
Expand Down

0 comments on commit 2d956b8

Please sign in to comment.