Add CUDA version check to skip building quantization ops for versions less than 8.0 (apache#10710)

* Add CUDA version check to skip building quantization ops for versions less than 8.0

* Clearer error messages for CUDA and cuDNN versions

* Trigger CI
reminisce authored and piiswrong committed Apr 29, 2018
1 parent 92c373b commit 23d933b
Showing 3 changed files with 18 additions and 10 deletions.
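
All three files get the same treatment: the code that needs int8 GPU support is wrapped in a preprocessor guard on CUDA_VERSION, which cuda.h defines as major * 1000 + minor * 10 (so 8000 means CUDA 8.0), and the stub left behind on older toolchains fails loudly at runtime instead of breaking the build. A minimal standalone sketch of the pattern (illustrative only, not MXNet code):

#include <cstdio>
#include <cstdlib>
#include <cuda.h>  // defines CUDA_VERSION, e.g. 8000 for CUDA 8.0

void QuantizedOpForwardGPU() {
#if CUDA_VERSION >= 8000
  // New enough toolkit: the real int8 path is compiled in here.
  std::printf("quantized kernel path (CUDA_VERSION=%d)\n", CUDA_VERSION);
#else
  // Older toolkit: the symbol still exists, so the build and graph
  // construction succeed, but invoking the op aborts with a clear message.
  std::fprintf(stderr, "quantized ops require CUDA >= 8.0\n");
  std::abort();
#endif
}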
11 changes: 6 additions & 5 deletions src/operator/quantization/quantized_conv.cu
@@ -49,7 +49,7 @@ struct QuantizedBiasAddKernel {
   }
 };
 
-#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000
 template<typename SrcType, typename DstType, typename CmpType>
 class QuantizedCuDNNConvOp {
  public:
@@ -260,7 +260,7 @@ class QuantizedCuDNNConvOp {
   float alpha_ = 1.0f;
   float beta_ = 0.0f;
 };  // class QuantizedCuDNNConvOp
-#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
+#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000
 
 void QuantizedConvForwardGPU(const nnvm::NodeAttrs& attrs,
                              const OpContext& ctx,
@@ -270,7 +270,7 @@ void QuantizedConvForwardGPU(const nnvm::NodeAttrs& attrs,
   const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
   CHECK_EQ(param.kernel.ndim(), 2U)
     << "QuantizedConvForward<gpu> only supports 2D convolution for now";
-#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000
   typedef QuantizedCuDNNConvOp<int8_t, float, int32_t> QuantizedConvOpInt8;
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local QuantizedConvOpInt8 op;
@@ -280,8 +280,9 @@ void QuantizedConvForwardGPU(const nnvm::NodeAttrs& attrs,
   op.Init(param, ctx, {inputs[0].shape_, inputs[1].shape_}, {outputs[0].shape_});
   op.Forward(ctx, inputs, req, outputs);
 #else
-  LOG(FATAL) << "QuantizedConvForward<gpu> only supports cudnnConvolutionForward for now";
-#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
+  LOG(FATAL) << "QuantizedConvForward<gpu> only supports cudnnConvolutionForward "
+                "with CUDNN >= 6.0 and CUDA >= 8.0";
+#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000
 }
 
 NNVM_REGISTER_OP(_contrib_quantized_conv)
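
One reason the guard has to cover the whole class rather than just its body: QuantizedConvForwardGPU materializes QuantizedCuDNNConvOp as a function-local static (thread_local where the compiler supports it), so merely naming the type on an old toolchain would be a compile error. A sketch of that caching pattern, with an illustrative ExpensiveOp standing in for the cuDNN op:

#include <cstdio>

struct ExpensiveOp {                 // stand-in for QuantizedCuDNNConvOp
  ExpensiveOp() { std::puts("cuDNN descriptors built once per thread"); }
  void Forward() { std::puts("forward pass"); }
};

void ForwardGPU() {
  // Mirrors `static thread_local QuantizedConvOpInt8 op;` in the diff:
  // constructed on first call in each thread, then reused, so descriptor
  // setup is not paid on every forward call.
  static thread_local ExpensiveOp op;
  op.Forward();
}

The DMLC_CXX11_THREAD_LOCAL branch in the diff simply picks between C++11 thread_local and a pre-C++11 compiler-specific equivalent.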
6 changes: 6 additions & 0 deletions src/operator/quantization/quantized_fully_connected.cu
@@ -30,6 +30,7 @@
 namespace mxnet {
 namespace op {
 
+#if CUDA_VERSION >= 8000
 // value + bias_value * (range1 / limit_range1) * (limit_range2 / range2)
 struct QuantizedBiasAddKernel {
   MSHADOW_XINLINE static void Map(int i, size_t k, int32_t *out,
@@ -49,13 +50,15 @@ struct QuantizedBiasAddKernel {
       float_for_one_out_quant;
   }
 };
+#endif  // CUDA_VERSION >= 8000
 
 template<typename SrcType, typename DstType, typename CmpType>
 void QuantizedFullyConnectedForwardGPU(const nnvm::NodeAttrs& attrs,
                                        const OpContext &ctx,
                                        const std::vector<TBlob> &inputs,
                                        const std::vector<OpReqType> &req,
                                        const std::vector<TBlob> &outputs) {
+#if CUDA_VERSION >= 8000
   const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
   using namespace mshadow;
   using namespace mxnet_op;
@@ -113,6 +116,9 @@ void QuantizedFullyConnectedForwardGPU(const nnvm::NodeAttrs& attrs,
       outputs[1].dptr<float>(), outputs[2].dptr<float>(),
       inputs[7].dptr<float>(), inputs[8].dptr<float>());
   }
+#else
+  LOG(FATAL) << "QuantizedFullyConnectedForwardGPU only supports CUDA >= 8.0";
+#endif  // CUDA_VERSION >= 8000
 }
 
 NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
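
The CUDA 8.0 floor here is presumably tied to the int8 GEMM this operator relies on (cublasGemmEx with 8-bit integer types first shipped with CUDA 8.0). The comment above QuantizedBiasAddKernel compresses the requantization math: the int8 bias carries a scale of range1 / limit_range1 (real value per quantization step), and multiplying by limit_range2 / range2 converts that real value into steps of the int32 output's scale, so the result can be added directly to the GEMM accumulator. A small numeric illustration with made-up ranges (not values from the codebase):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical quantization ranges, for illustration only.
  const float range1 = 0.5f,   limit1 = 127.0f;         // int8 bias in [-0.5, 0.5]
  const float range2 = 200.0f, limit2 = 2147483647.0f;  // int32 output range
  const int32_t acc  = 1000;  // one GEMM accumulator entry
  const int8_t  bias = 64;    // one quantized bias entry
  // value + bias_value * (range1 / limit_range1) * (limit_range2 / range2)
  const float out = acc + bias * (range1 / limit1) * (limit2 / range2);
  std::printf("requantized output entry: %.0f\n", out);
  return 0;
}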
11 changes: 6 additions & 5 deletions src/operator/quantization/quantized_pooling.cu
@@ -29,7 +29,7 @@
 namespace mxnet {
 namespace op {
 
-#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000
 template<typename DType>
 class QuantizedCuDNNPoolingOp {
  public:
@@ -115,7 +115,7 @@ class QuantizedCuDNNPoolingOp {
   cudnnTensorDescriptor_t out_desc_;
   cudnnPoolingDescriptor_t pool_desc_;
 };  // class QuantizedCuDNNPoolingOp
-#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
+#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000
 
 void QuantizedPoolingForwardGPU(const nnvm::NodeAttrs& attrs,
                                 const OpContext& ctx,
@@ -125,7 +125,7 @@ void QuantizedPoolingForwardGPU(const nnvm::NodeAttrs& attrs,
   const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
   CHECK_EQ(param.kernel.ndim(), 2U)
     << "QuantizedPoolingForward<gpu> only supports 2D convolution for now";
-#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local QuantizedCuDNNPoolingOp<int8_t> op;
 #else
@@ -134,8 +134,9 @@ void QuantizedPoolingForwardGPU(const nnvm::NodeAttrs& attrs,
   op.Init(param, {inputs[0].shape_}, {outputs[0].shape_});
   op.Forward(ctx.get_stream<gpu>(), inputs, req, outputs);
 #else
-  LOG(FATAL) << "QuantizedPoolingForward<gpu> only supports cudnnPoolingForward for now";
-#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6
+  LOG(FATAL) << "QuantizedPoolingForward<gpu> only supports cudnnPoolingForward "
+                "with CUDNN >= 6.0 and CUDA >= 8.0";
+#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 6 && CUDA_VERSION >= 8000
 }
 
 NNVM_REGISTER_OP(_contrib_quantized_pooling)
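
The pooling change has the same shape as the conv change. To see what these guards will evaluate to on a given machine, a throwaway program that prints the macros the preprocessor sees can help (not part of this patch):

#include <cstdio>
#include <cuda.h>    // CUDA_VERSION
#include <cudnn.h>   // CUDNN_MAJOR

int main() {
  // A CUDA 8.0 / cuDNN 6.x toolchain prints: CUDA_VERSION=8000 CUDNN_MAJOR=6
  std::printf("CUDA_VERSION=%d CUDNN_MAJOR=%d\n", CUDA_VERSION, CUDNN_MAJOR);
  return 0;
}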
