
Commit 03fb5da

timshen91 authored and tensorflower-gardener committed
Simplify PrepareForConvolution:

* Change the virtual functions overloaded on DeviceMemory<T> to a single function taking DeviceMemoryBase.
* Create a non-virtual, template wrapper that takes DeviceMemory<T>.

PiperOrigin-RevId: 228818212
1 parent ed7a8b4 commit 03fb5da
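
The change follows an interface-narrowing pattern: the per-type virtual overloads are collapsed into one type-erased virtual that takes DeviceMemoryBase, and typed call sites are preserved by a non-virtual template wrapper on the base class that deduces the element type and forwards. The sketch below only illustrates the shape of that pattern, using simplified, hypothetical names (DnnSupport, ToDataType, FakeBackend); it is not the actual stream_executor/dnn.h declaration, whose signatures are much longer.

#include <cstdint>
#include <iostream>

// Simplified stand-ins for StreamExecutor types (hypothetical, illustration only).
struct DeviceMemoryBase {
  void* opaque = nullptr;
  uint64_t size = 0;
};

template <typename T>
struct DeviceMemory : DeviceMemoryBase {};  // typed view over the same handle

enum class DataType { kFloat, kDouble, kHalf };

template <typename T>
struct ToDataType;
template <>
struct ToDataType<float> {
  static constexpr DataType value = DataType::kFloat;
};
template <>
struct ToDataType<double> {
  static constexpr DataType value = DataType::kDouble;
};

class DnnSupport {
 public:
  virtual ~DnnSupport() = default;

  // Non-virtual template wrapper: keeps typed call sites, deduces the element
  // type, and erases DeviceMemory<T> down to DeviceMemoryBase.
  template <typename T>
  bool PrepareForConvolution(const DeviceMemory<T>& input,
                             const DeviceMemory<T>& filter,
                             DeviceMemory<T>* output) {
    return DoPrepareForConvolution(ToDataType<T>::value, input, filter,
                                   *output);
  }

 protected:
  // Single type-erased virtual: a backend overrides one function instead of
  // one overload per element type.
  virtual bool DoPrepareForConvolution(DataType element_type,
                                       DeviceMemoryBase input,
                                       DeviceMemoryBase filter,
                                       DeviceMemoryBase output) = 0;
};

class FakeBackend : public DnnSupport {
 protected:
  bool DoPrepareForConvolution(DataType element_type, DeviceMemoryBase,
                               DeviceMemoryBase, DeviceMemoryBase) override {
    std::cout << "prepare, element_type=" << static_cast<int>(element_type)
              << "\n";
    return true;
  }
};

int main() {
  FakeBackend backend;
  DeviceMemory<float> input, filter, output;
  backend.PrepareForConvolution(input, filter, &output);  // dispatches as kFloat
}

Collapsing the overload set means a backend implements a single virtual hook, while the template wrapper keeps overload resolution and type safety at the call site.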

File tree: 5 files changed (+216, -631 lines)

tensorflow/stream_executor/cuda/cuda_dnn.cc

Lines changed: 60 additions & 265 deletions
@@ -995,9 +995,11 @@ cudnnDataType_t ToCudnnDataType(
     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
   switch (data_type) {
     case dnn::DataType::kFloat:
+      return CUDNN_DATA_FLOAT;
     case dnn::DataType::kDouble:
+      return CUDNN_DATA_DOUBLE;
     case dnn::DataType::kHalf:
-      return static_cast<cudnnDataType_t>(data_type);
+      return CUDNN_DATA_HALF;
     case dnn::DataType::kInt8:
       return data_layout == dnn::DataLayout::kBatchDepthYX4 ? CUDNN_DATA_INT8x4
                                                             : CUDNN_DATA_INT8;
@@ -1008,6 +1010,15 @@ cudnnDataType_t ToCudnnDataType(
   }
 }

+cudnnDataType_t ToCudnnDataType(dnn::DataType data_type,
+                                dnn::FilterLayout filter_layout) {
+  if (data_type == dnn::DataType::kInt8 &&
+      filter_layout == dnn::FilterLayout::kOutputInputYX4) {
+    return CUDNN_DATA_INT8x4;
+  }
+  return ToCudnnDataType(data_type);
+}
+
 template <typename T>
 cudnnDataType_t GetCudnnDataType(
     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
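
The new two-argument ToCudnnDataType overload added above folds the int8x4 special case into the type mapping: int8 data with filters in the vectorized kOutputInputYX4 layout maps to CUDNN_DATA_INT8x4, and every other combination defers to the one-argument overload. A self-contained sketch of that selection logic, using mock enums instead of the real cuDNN/StreamExecutor headers (illustration only):

#include <cassert>

// Mock stand-ins for the cuDNN and StreamExecutor enums (illustration only).
enum cudnnDataType_t { CUDNN_DATA_FLOAT, CUDNN_DATA_DOUBLE, CUDNN_DATA_HALF,
                       CUDNN_DATA_INT8, CUDNN_DATA_INT8x4 };
enum class DataType { kFloat, kDouble, kHalf, kInt8 };
enum class FilterLayout { kOutputInputYX, kOutputInputYX4 };

cudnnDataType_t ToCudnnDataType(DataType t) {
  switch (t) {
    case DataType::kFloat:  return CUDNN_DATA_FLOAT;
    case DataType::kDouble: return CUDNN_DATA_DOUBLE;
    case DataType::kHalf:   return CUDNN_DATA_HALF;
    case DataType::kInt8:   return CUDNN_DATA_INT8;
  }
  return CUDNN_DATA_FLOAT;  // unreachable
}

// Layout-aware overload: int8 filters in the vectorized OIYX4 layout map to
// the INT8x4 cuDNN type; everything else falls back to the plain mapping.
cudnnDataType_t ToCudnnDataType(DataType t, FilterLayout layout) {
  if (t == DataType::kInt8 && layout == FilterLayout::kOutputInputYX4) {
    return CUDNN_DATA_INT8x4;
  }
  return ToCudnnDataType(t);
}

int main() {
  assert(ToCudnnDataType(DataType::kInt8, FilterLayout::kOutputInputYX4) ==
         CUDNN_DATA_INT8x4);
  assert(ToCudnnDataType(DataType::kInt8, FilterLayout::kOutputInputYX) ==
         CUDNN_DATA_INT8);
  assert(ToCudnnDataType(DataType::kFloat, FilterLayout::kOutputInputYX4) ==
         CUDNN_DATA_FLOAT);
}
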
@@ -2815,30 +2826,60 @@ void LogCudaProto(const dnn::ConvolutionProto& conv, float profile_time_ms,

 }  // namespace

-template <class T>
-port::Status CudnnSupport::PrepareForConvolutionImpl(
-    Stream* stream, const dnn::BatchDescriptor& input_descriptor,
-    const DeviceMemory<T>& input_data,
+port::Status CudnnSupport::DoPrepareForConvolution(
+    dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
+    const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
     const dnn::FilterDescriptor& filter_descriptor,
-    const DeviceMemory<T>& filter_data,
+    DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
+    DeviceMemoryBase output_data,
     const dnn::ConvolutionDescriptor& convolution_descriptor,
-    const dnn::BatchDescriptor& output_descriptor, DeviceMemory<T>* output_data,
-    dnn::DataType accumulator_type, ScratchAllocator* scratch_allocator,
     const dnn::AlgorithmConfig& algorithm_config,
-    dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
-  cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
-  CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
-  CudnnTensorDescriptor output_nd(output_descriptor, cudnn_type);
-  CudnnFilterDescriptor filter(filter_descriptor, cudnn_type);
-  CudnnConvolutionDescriptor conv(convolution_descriptor,
-                                  ToCudnnDataType(accumulator_type));
+    ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
+    DeviceMemory<uint8>* scratch_memory) {
+  CudnnTensorDescriptor input_nd(
+      input_descriptor,
+      ToCudnnDataType(element_type, input_descriptor.layout()));
+  CudnnFilterDescriptor filter_nd(
+      filter_descriptor,
+      ToCudnnDataType(element_type, filter_descriptor.layout()));
+  CudnnTensorDescriptor output_nd(
+      output_descriptor,
+      ToCudnnDataType(element_type, output_descriptor.layout()));
+  CudnnConvolutionDescriptor conv(
+      convolution_descriptor,
+      ToCudnnDataType(GetConvAccumulatorType(element_type)));

   auto cudnn = cudnn_->GetHandle(parent_, stream);

-  SE_ASSIGN_OR_RETURN(*algorithm_desc,
-                      GetCudnnConvolutionForwardAlgorithm(
-                          stream, cudnn, algorithm_config, input_nd, filter,
-                          conv, output_nd, scratch_allocator, scratch_memory));
+  switch (kind) {
+    case dnn::ConvolutionKind::FORWARD: {
+      SE_ASSIGN_OR_RETURN(
+          *algorithm_desc,
+          GetCudnnConvolutionForwardAlgorithm(
+              stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
+              output_nd, scratch_allocator, scratch_memory));
+      break;
+    }
+    case dnn::ConvolutionKind::BACKWARD_DATA: {
+      SE_ASSIGN_OR_RETURN(
+          *algorithm_desc,
+          GetCudnnConvolutionBackwardDataAlgorithm(
+              stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
+              output_nd, scratch_allocator, scratch_memory));
+      break;
+    }
+    case dnn::ConvolutionKind::BACKWARD_FILTER: {
+      SE_ASSIGN_OR_RETURN(
+          *algorithm_desc,
+          GetCudnnConvolutionBackwardFilterAlgorithm(
+              stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
+              output_nd, scratch_allocator, scratch_memory));
+      break;
+    }
+    default:
+      return port::InternalError(
+          absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
+  }

   return port::Status::OK();
 }
@@ -3351,64 +3392,6 @@ port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
   return port::Status::OK();
 }

-bool CudnnSupport::PrepareForConvolution(
-    Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
-    const DeviceMemory<float>& input_data,
-    const dnn::FilterDescriptor& filter_descriptor,
-    const DeviceMemory<float>& filter_data,
-    const dnn::ConvolutionDescriptor& convolution_descriptor,
-    const dnn::BatchDescriptor& output_descriptor,
-    DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
-    const dnn::AlgorithmConfig& algorithm_config,
-    dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
-  return IsStatusOk(PrepareForConvolutionImpl<float>(
-                        stream, batch_descriptor, input_data, filter_descriptor,
-                        filter_data, convolution_descriptor, output_descriptor,
-                        output_data, dnn::DataType::kFloat, scratch_allocator,
-                        algorithm_config, algorithm_desc, scratch_memory),
-                    /*report_error=*/true);
-}
-
-bool CudnnSupport::PrepareForConvolution(
-    Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
-    const DeviceMemory<double>& input_data,
-    const dnn::FilterDescriptor& filter_descriptor,
-    const DeviceMemory<double>& filter_data,
-    const dnn::ConvolutionDescriptor& convolution_descriptor,
-    const dnn::BatchDescriptor& output_descriptor,
-    DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
-    const dnn::AlgorithmConfig& algorithm_config,
-    dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
-  return IsStatusOk(PrepareForConvolutionImpl<double>(
-                        stream, batch_descriptor, input_data, filter_descriptor,
-                        filter_data, convolution_descriptor, output_descriptor,
-                        output_data, dnn::DataType::kDouble, scratch_allocator,
-                        algorithm_config, algorithm_desc, scratch_memory),
-                    /*report_error=*/true);
-}
-
-bool CudnnSupport::PrepareForConvolution(
-    Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
-    const DeviceMemory<Eigen::half>& input_data,
-    const dnn::FilterDescriptor& filter_descriptor,
-    const DeviceMemory<Eigen::half>& filter_data,
-    const dnn::ConvolutionDescriptor& convolution_descriptor,
-    const dnn::BatchDescriptor& output_descriptor,
-    DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
-    const dnn::AlgorithmConfig& algorithm_config,
-    dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
-  dnn::DataType acc_type =
-      CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
-          ? dnn::DataType::kFloat
-          : dnn::DataType::kHalf;
-  return IsStatusOk(
-      PrepareForConvolutionImpl<Eigen::half>(
-          stream, batch_descriptor, input_data, filter_descriptor, filter_data,
-          convolution_descriptor, output_descriptor, output_data, acc_type,
-          scratch_allocator, algorithm_config, algorithm_desc, scratch_memory),
-      /*report_error=*/true);
-}
-
 bool CudnnSupport::DoConvolve(
     Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
     const DeviceMemory<float>& input_data,
@@ -3592,36 +3575,6 @@ bool CudnnSupport::DoTransformTensor(Stream* stream,
   return IsStatusOk(status, /*report_error=*/true);
 }

-template <class T>
-port::Status CudnnSupport::PrepareForConvolutionBackwardDataImpl(
-    Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
-    const DeviceMemory<T>& filter_data,
-    const dnn::BatchDescriptor& output_descriptor,
-    DeviceMemory<T> backward_output_data,
-    const dnn::ConvolutionDescriptor& convolution_descriptor,
-    const dnn::BatchDescriptor& input_descriptor,
-    DeviceMemory<T>* backward_input_data, dnn::DataType accumulator_type,
-    ScratchAllocator* scratch_allocator,
-    const dnn::AlgorithmConfig& algorithm_config,
-    dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
-  cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
-  auto cudnn = cudnn_->GetHandle(parent_, stream);
-
-  CudnnTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
-  CudnnTensorDescriptor in_back_nd(input_descriptor, cudnn_type);
-  CudnnFilterDescriptor filter(filter_descriptor, cudnn_type);
-  CudnnConvolutionDescriptor conv(convolution_descriptor,
-                                  ToCudnnDataType(accumulator_type));
-
-  SE_ASSIGN_OR_RETURN(
-      *algorithm_desc,
-      GetCudnnConvolutionBackwardDataAlgorithm(
-          stream, cudnn, algorithm_config, in_back_nd, filter, conv,
-          out_back_nd, scratch_allocator, scratch_memory));
-
-  return port::Status::OK();
-}
-
 template <class T>
 port::Status CudnnSupport::DoConvolveBackwardDataImpl(
     Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
@@ -3722,70 +3675,6 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl(
   return port::Status::OK();
 }

-bool CudnnSupport::PrepareForConvolutionBackwardData(
-    Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
-    const DeviceMemory<double>& filter_data,
-    const dnn::BatchDescriptor& output_descriptor,
-    DeviceMemory<double> backward_output_data,
-    const dnn::ConvolutionDescriptor& convolution_descriptor,
-    const dnn::BatchDescriptor& input_descriptor,
-    DeviceMemory<double>* backward_input_data,
-    ScratchAllocator* scratch_allocator,
-    const dnn::AlgorithmConfig& algorithm_config,
-    dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
-  return IsStatusOk(
-      PrepareForConvolutionBackwardDataImpl(
-          stream, filter_descriptor, filter_data, output_descriptor,
-          backward_output_data, convolution_descriptor, input_descriptor,
-          backward_input_data, dnn::DataType::kDouble, scratch_allocator,
-          algorithm_config, algorithm_desc, scratch_memory),
-      /*report_error=*/true);
-}
-
-bool CudnnSupport::PrepareForConvolutionBackwardData(
-    Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
-    const DeviceMemory<float>& filter_data,
-    const dnn::BatchDescriptor& output_descriptor,
-    DeviceMemory<float> backward_output_data,
-    const dnn::ConvolutionDescriptor& convolution_descriptor,
-    const dnn::BatchDescriptor& input_descriptor,
-    DeviceMemory<float>* backward_input_data,
-    ScratchAllocator* scratch_allocator,
-    const dnn::AlgorithmConfig& algorithm_config,
-    dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
-  return IsStatusOk(
-      PrepareForConvolutionBackwardDataImpl(
-          stream, filter_descriptor, filter_data, output_descriptor,
-          backward_output_data, convolution_descriptor, input_descriptor,
-          backward_input_data, dnn::DataType::kFloat, scratch_allocator,
-          algorithm_config, algorithm_desc, scratch_memory),
-      /*report_error=*/true);
-}
-
-bool CudnnSupport::PrepareForConvolutionBackwardData(
-    Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
-    const DeviceMemory<Eigen::half>& filter_data,
-    const dnn::BatchDescriptor& output_descriptor,
-    DeviceMemory<Eigen::half> backward_output_data,
-    const dnn::ConvolutionDescriptor& convolution_descriptor,
-    const dnn::BatchDescriptor& input_descriptor,
-    DeviceMemory<Eigen::half>* backward_input_data,
-    ScratchAllocator* scratch_allocator,
-    const dnn::AlgorithmConfig& algorithm_config,
-    dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
-  dnn::DataType acc_type =
-      CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
-          ? dnn::DataType::kFloat
-          : dnn::DataType::kHalf;
-  return IsStatusOk(
-      PrepareForConvolutionBackwardDataImpl(
-          stream, filter_descriptor, filter_data, output_descriptor,
-          backward_output_data, convolution_descriptor, input_descriptor,
-          backward_input_data, acc_type, scratch_allocator, algorithm_config,
-          algorithm_desc, scratch_memory),
-      /*report_error=*/true);
-}
-
 bool CudnnSupport::DoConvolveBackwardData(
     Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
     const DeviceMemory<double>& filter_data,
@@ -3846,36 +3735,6 @@ bool CudnnSupport::DoConvolveBackwardData(
                     /*report_error=*/!output_profile_result);
 }

-template <class T>
-port::Status CudnnSupport::PrepareForConvolutionBackwardFilterImpl(
-    Stream* stream, const dnn::BatchDescriptor& input_descriptor,
-    const DeviceMemory<T>& input_data,
-    const dnn::BatchDescriptor& output_descriptor,
-    DeviceMemory<T> backward_output_data,
-    const dnn::ConvolutionDescriptor& convolution_descriptor,
-    const dnn::FilterDescriptor& filter_descriptor,
-    DeviceMemory<T>* backward_filter_data, dnn::DataType accumulator_type,
-    ScratchAllocator* scratch_allocator,
-    const dnn::AlgorithmConfig& algorithm_config,
-    dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
-  cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
-  auto cudnn = cudnn_->GetHandle(parent_, stream);
-
-  CudnnTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
-  CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
-  CudnnFilterDescriptor filter(filter_descriptor, cudnn_type);
-  CudnnConvolutionDescriptor conv(convolution_descriptor,
-                                  ToCudnnDataType(accumulator_type));
-
-  SE_ASSIGN_OR_RETURN(
-      *algorithm_desc,
-      GetCudnnConvolutionBackwardFilterAlgorithm(
-          stream, cudnn, algorithm_config, input_nd, filter, conv, out_back_nd,
-          scratch_allocator, scratch_memory));
-
-  return port::Status::OK();
-}
-
 template <class T>
 port::Status CudnnSupport::DoConvolveBackwardFilterImpl(
     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
@@ -4013,70 +3872,6 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl(
   return port::Status::OK();
 }

-bool CudnnSupport::PrepareForConvolutionBackwardFilter(
-    Stream* stream, const dnn::BatchDescriptor& input_descriptor,
-    const DeviceMemory<double>& input_data,
-    const dnn::BatchDescriptor& output_descriptor,
-    DeviceMemory<double> backward_output_data,
-    const dnn::ConvolutionDescriptor& convolution_descriptor,
-    const dnn::FilterDescriptor& filter_descriptor,
-    DeviceMemory<double>* backward_filter_data,
-    ScratchAllocator* scratch_allocator,
-    const dnn::AlgorithmConfig& algorithm_config,
-    dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
-  return IsStatusOk(
-      PrepareForConvolutionBackwardFilterImpl(
-          stream, input_descriptor, input_data, output_descriptor,
-          backward_output_data, convolution_descriptor, filter_descriptor,
-          backward_filter_data, dnn::DataType::kDouble, scratch_allocator,
-          algorithm_config, algorithm_desc, scratch_memory),
-      /*report_error=*/true);
-}
-
-bool CudnnSupport::PrepareForConvolutionBackwardFilter(
-    Stream* stream, const dnn::BatchDescriptor& input_descriptor,
-    const DeviceMemory<float>& input_data,
-    const dnn::BatchDescriptor& output_descriptor,
-    DeviceMemory<float> backward_output_data,
-    const dnn::ConvolutionDescriptor& convolution_descriptor,
-    const dnn::FilterDescriptor& filter_descriptor,
-    DeviceMemory<float>* backward_filter_data,
-    ScratchAllocator* scratch_allocator,
-    const dnn::AlgorithmConfig& algorithm_config,
-    dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
-  return IsStatusOk(
-      PrepareForConvolutionBackwardFilterImpl(
-          stream, input_descriptor, input_data, output_descriptor,
-          backward_output_data, convolution_descriptor, filter_descriptor,
-          backward_filter_data, dnn::DataType::kFloat, scratch_allocator,
-          algorithm_config, algorithm_desc, scratch_memory),
-      /*report_error=*/true);
-}
-
-bool CudnnSupport::PrepareForConvolutionBackwardFilter(
-    Stream* stream, const dnn::BatchDescriptor& input_descriptor,
-    const DeviceMemory<Eigen::half>& input_data,
-    const dnn::BatchDescriptor& output_descriptor,
-    DeviceMemory<Eigen::half> backward_output_data,
-    const dnn::ConvolutionDescriptor& convolution_descriptor,
-    const dnn::FilterDescriptor& filter_descriptor,
-    DeviceMemory<Eigen::half>* backward_filter_data,
-    ScratchAllocator* scratch_allocator,
-    const dnn::AlgorithmConfig& algorithm_config,
-    dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
-  dnn::DataType acc_type =
-      CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
-          ? dnn::DataType::kFloat
-          : dnn::DataType::kHalf;
-  return IsStatusOk(
-      PrepareForConvolutionBackwardFilterImpl(
-          stream, input_descriptor, input_data, output_descriptor,
-          backward_output_data, convolution_descriptor, filter_descriptor,
-          backward_filter_data, acc_type, scratch_allocator, algorithm_config,
-          algorithm_desc, scratch_memory),
-      /*report_error=*/true);
-}
-
 bool CudnnSupport::DoConvolveBackwardFilter(
     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
     const DeviceMemory<double>& input_data,
