Performance improvements for depthwise convolutions in FP16 (pytorch#22302)

Summary:
This PR activates faster depthwise convolution kernels for Volta and Turing GPUs when running with cuDNN >= 7600 (i.e. cuDNN 7.6.0).
The script used to benchmark the current PyTorch master branch against this PR's branch can be found [here](https://gist.github.com/ptrblck/4590cf20721d8f43296c9903abd4a774) (50 warmup iterations, 1000 timed iterations).
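
As a quick sanity check that a given machine meets these prerequisites, the relevant properties can be queried from Python; `fast_depthwise_prereqs` below is a hypothetical helper sketched for illustration, not part of this PR:

```python
import torch

def fast_depthwise_prereqs():
    # Volta/Turing or newer (compute capability >= 7.0) with cuDNN >= 7.6.0,
    # which torch.backends.cudnn.version() reports as the integer 7600.
    if not (torch.cuda.is_available() and torch.backends.cudnn.is_available()):
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 7 and torch.backends.cudnn.version() >= 7600
```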

I've used pytorch#3265 to create a similar benchmark and added a few additional setups.
Since the results are quite long, I've uploaded them to a spreadsheet [here](https://docs.google.com/spreadsheets/d/13ByXcqg7LQUr3DVG3XpLwnJ-CXg3GUZJ3puyTMw9n2I/edit?usp=sharing).
Times are given in ms per iteration.
We've benchmarked this PR on a DGX1 using V100 GPUs.
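
For reference, a minimal sketch of the timing methodology described above (warmup iterations followed by CUDA-event timing); the layer shape and helper name are illustrative assumptions, not taken from the linked gist:

```python
import torch
import torch.nn as nn

def time_depthwise(bs=32, ch=256, w=56, kernel=3, stride=1, warmup=50, iters=1000):
    # Depthwise convolution: groups == in_channels, FP16 tensors on the GPU.
    conv = nn.Conv2d(ch, ch, kernel, stride=stride, padding=kernel // 2,
                     groups=ch, bias=False).cuda().half()
    x = torch.randn(bs, ch, w, w, device='cuda', dtype=torch.half)
    with torch.no_grad():
        for _ in range(warmup):
            conv(x)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            conv(x)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per iteration

print(time_depthwise())
```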

The current workload check in `check_cudnn_depthwise_workload` is quite long and could be moved to a separate file if desired.

CC ngimel (Thanks for the support while benchmarking it ;) )
Pull Request resolved: pytorch#22302

Differential Revision: D16115057

Pulled By: ezyang

fbshipit-source-id: bad184658518e73b4d6b849d77e408f5a7a757de
ptrblck authored and facebook-github-bot committed Jul 9, 2019
1 parent 31d821e commit a3346e1
Showing 4 changed files with 165 additions and 1 deletion.
14 changes: 14 additions & 0 deletions aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -137,6 +137,20 @@ bool CUDAHooks::supportsDilatedConvolutionWithCuDNN() const {
#endif
}

bool CUDAHooks::supportsDepthwiseConvolutionWithCuDNN() const {
#if AT_CUDNN_ENABLED()
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  // Check for Volta cores
  if (prop->major >= 7) {
    return true;
  } else {
    return false;
  }
#else
  return false;
#endif
}

long CUDAHooks::versionCuDNN() const {
#if AT_CUDNN_ENABLED()
  return CUDNN_VERSION;
1 change: 1 addition & 0 deletions aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -22,6 +22,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
  bool compiledWithCuDNN() const override;
  bool compiledWithMIOpen() const override;
  bool supportsDilatedConvolutionWithCuDNN() const override;
  bool supportsDepthwiseConvolutionWithCuDNN() const override;
  long versionCuDNN() const override;
  std::string showConfig() const override;
  double batchnormMinEpsilonCuDNN() const override;
4 changes: 4 additions & 0 deletions aten/src/ATen/detail/CUDAHooksInterface.h
@@ -107,6 +107,10 @@ struct CAFFE2_API CUDAHooksInterface {
    return false;
  }

  virtual bool supportsDepthwiseConvolutionWithCuDNN() const {
    return false;
  }

  virtual long versionCuDNN() const {
    AT_ERROR("Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
  }
147 changes: 146 additions & 1 deletion aten/src/ATen/native/Convolution.cpp
@@ -31,6 +31,7 @@ struct ConvParams {
  bool is_stride_neg() const;
  void view1d_as_2d();
  bool use_cudnn(const at::Tensor& input) const;
  bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
  bool use_miopen(const at::Tensor& input) const;
  bool use_mkldnn(const at::Tensor& input) const;
  bool use_nnpack(const at::Tensor& input) const;
@@ -187,6 +188,143 @@ auto ConvParams::is_depthwise(
         weight.size(0) % input.size(1) == 0; // output channels must be a multiple of input channels
}

// Check workload to activate fast depthwise FP16 cudnn conv kernels
bool check_cudnn_depthwise_workload(const at::Tensor& input, int stride) {
  int w = input.size(3);  // same as h
  int ch = input.size(1);
  int bs = input.size(0);
  if (stride == 1) {
    if (w >= 7) {
      // All batch sizes and nb_channels
      if (w >= 112) {
        return true;
      }

      // large nb_channels
      if (ch >= 1024) {
        if (w >= 56) {
          return true;
        } else if (bs >= 32) {
          return true;
        }
      }

      // batch_size specific
      if (bs >= 128) {
        if (ch >= 512) {
          return true;
        } else if (ch >= 64) {
          if (w >= 14) {
            return true;
          }
        } else if ((ch >= 32) && (w >= 28)) {
          return true;
        }
      } else if (bs >= 64) {
        if ((ch >= 256) && (w >= 14)) {
          return true;
        } else if ((ch >= 32) && (w >= 28)) {
          return true;
        }
      } else if (bs >= 32) {
        if ((ch >= 256) && (w >= 14)) {
          return true;
        } else if ((ch >= 128) && (w >= 28)) {
          return true;
        } else if ((ch >= 32) && (w >= 56)) {
          return true;
        }
      } else if (bs >= 16) {
        if ((ch >= 1024) && (w >= 14)) {
          return true;
        }
        if ((ch >= 256) && (w >= 28)) {
          return true;
        } else if ((ch >= 32) && (w >= 56)) {
          return true;
        }
      } else if (bs >= 8) {
        if ((ch >= 512) && (w >= 28)) {
          return true;
        } else if ((ch >= 64) && (w >= 56)) {
          return true;
        }
      }
    }
  } else if (stride == 2) {
    if (ch < 256) {
      return false;
    }

    if (w >= 7) {
      if (bs >= 128) {
        if (ch >= 1024) {
          return true;
        } else if ((ch >= 512) && (w >= 14)) {
          return true;
        } else if (w >= 28) {
          return true;
        }
      } else if (bs >= 64) {
        if ((ch >= 512) && (w >= 14)) {
          return true;
        } else if (w >= 28) {
          return true;
        }
      } else if (bs >= 32) {
        if ((ch >= 1024) && (w >= 14)) {
          return true;
        } else if (w >= 28) {
          return true;
        }
      } else if (bs >= 16) {
        if ((ch >= 512) && (w >= 28)) {
          return true;
        } else if (w >= 56) {
          return true;
        }
      } else if (bs >= 8) {
        if ((ch >= 1024) && (w >= 28)) {
          return true;
        } else if (w >= 56) {
          return true;
        }
      } else if (bs >= 1) {
        if ((ch >= 512) && (w >= 112)) {
          return true;
        }
      }
    }
  }
  return false;
}

// Use cudnn for FP16 depthwise convolutions
auto ConvParams::use_cudnn_depthwise(
    const at::Tensor& input, const at::Tensor& weight) const -> bool {
  if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) {
    long cudnn_version = detail::getCUDAHooks().versionCuDNN();
    bool kernel_cond = (cudnn_version >= 7600 &&
                        use_cudnn(input) &&
                        input.scalar_type() == kHalf && // only for FP16
                        weight.scalar_type() == kHalf &&
                        is_depthwise(input, weight) &&
                        weight.size(2) == weight.size(3) && // only square kernels
                        input.size(2) >= 7 && // min width/height 7
                        !is_dilated() && // no dilation supported
                        stride[0] == stride[1] && // equal strides
                        ((weight.size(3) == 3) || (weight.size(3) == 1)) &&
                        input.size(1) >= 32); // min 32 channels supported
    if (kernel_cond) {
      return check_cudnn_depthwise_workload(input, stride[0]);
    } else {
      return false;
    }
  } else {
    return false;
  }
}

static void check_shape_forward(const at::Tensor& input,
                                const at::Tensor& weight, const at::Tensor& bias,
                                const ConvParams& params, bool input_is_mkldnn) {
@@ -407,7 +545,14 @@ at::Tensor _convolution(
    auto stride = params.stride;
    auto padding = params.padding;
    auto dilation = params.dilation;
    if (params.use_cudnn_depthwise(input, weight)) {
      output = at::cudnn_convolution(
          input, weight, bias,
          padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
    } else {
      output = at::thnn_conv_depthwise2d(input, weight, kernel_size, bias, stride, padding, dilation);
    }
  } else if (params.use_cudnn(input)) {
    TORCH_CHECK(input.type() == weight.type(),
      "Input type (", input.type().toString(), ") and weight type (", weight.type().toString(),
