@@ -995,9 +995,11 @@ cudnnDataType_t ToCudnnDataType(
995
995
dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX ) {
996
996
switch (data_type) {
997
997
case dnn::DataType::kFloat :
998
+ return CUDNN_DATA_FLOAT;
998
999
case dnn::DataType::kDouble :
1000
+ return CUDNN_DATA_DOUBLE;
999
1001
case dnn::DataType::kHalf :
1000
- return static_cast <cudnnDataType_t>(data_type) ;
1002
+ return CUDNN_DATA_HALF ;
1001
1003
case dnn::DataType::kInt8 :
1002
1004
return data_layout == dnn::DataLayout::kBatchDepthYX4 ? CUDNN_DATA_INT8x4
1003
1005
: CUDNN_DATA_INT8;
@@ -1008,6 +1010,15 @@ cudnnDataType_t ToCudnnDataType(
1008
1010
}
1009
1011
}
1010
1012
1013
+ cudnnDataType_t ToCudnnDataType (dnn::DataType data_type,
1014
+ dnn::FilterLayout filter_layout) {
1015
+ if (data_type == dnn::DataType::kInt8 &&
1016
+ filter_layout == dnn::FilterLayout::kOutputInputYX4 ) {
1017
+ return CUDNN_DATA_INT8x4;
1018
+ }
1019
+ return ToCudnnDataType (data_type);
1020
+ }
1021
+
1011
1022
template <typename T>
1012
1023
cudnnDataType_t GetCudnnDataType (
1013
1024
dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX ) {
@@ -2815,30 +2826,60 @@ void LogCudaProto(const dnn::ConvolutionProto& conv, float profile_time_ms,
2815
2826
2816
2827
} // namespace
2817
2828
2818
- template <class T >
2819
- port::Status CudnnSupport::PrepareForConvolutionImpl (
2820
- Stream* stream, const dnn::BatchDescriptor& input_descriptor,
2821
- const DeviceMemory<T>& input_data,
2829
+ port::Status CudnnSupport::DoPrepareForConvolution (
2830
+ dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
2831
+ const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2822
2832
const dnn::FilterDescriptor& filter_descriptor,
2823
- const DeviceMemory<T>& filter_data,
2833
+ DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2834
+ DeviceMemoryBase output_data,
2824
2835
const dnn::ConvolutionDescriptor& convolution_descriptor,
2825
- const dnn::BatchDescriptor& output_descriptor, DeviceMemory<T>* output_data,
2826
- dnn::DataType accumulator_type, ScratchAllocator* scratch_allocator,
2827
2836
const dnn::AlgorithmConfig& algorithm_config,
2828
- dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
2829
- cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
2830
- CudnnTensorDescriptor input_nd (input_descriptor, cudnn_type);
2831
- CudnnTensorDescriptor output_nd (output_descriptor, cudnn_type);
2832
- CudnnFilterDescriptor filter (filter_descriptor, cudnn_type);
2833
- CudnnConvolutionDescriptor conv (convolution_descriptor,
2834
- ToCudnnDataType (accumulator_type));
2837
+ ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
2838
+ DeviceMemory<uint8>* scratch_memory) {
2839
+ CudnnTensorDescriptor input_nd (
2840
+ input_descriptor,
2841
+ ToCudnnDataType (element_type, input_descriptor.layout ()));
2842
+ CudnnFilterDescriptor filter_nd (
2843
+ filter_descriptor,
2844
+ ToCudnnDataType (element_type, filter_descriptor.layout ()));
2845
+ CudnnTensorDescriptor output_nd (
2846
+ output_descriptor,
2847
+ ToCudnnDataType (element_type, output_descriptor.layout ()));
2848
+ CudnnConvolutionDescriptor conv (
2849
+ convolution_descriptor,
2850
+ ToCudnnDataType (GetConvAccumulatorType (element_type)));
2835
2851
2836
2852
auto cudnn = cudnn_->GetHandle (parent_, stream);
2837
2853
2838
- SE_ASSIGN_OR_RETURN (*algorithm_desc,
2839
- GetCudnnConvolutionForwardAlgorithm (
2840
- stream, cudnn, algorithm_config, input_nd, filter,
2841
- conv, output_nd, scratch_allocator, scratch_memory));
2854
+ switch (kind) {
2855
+ case dnn::ConvolutionKind::FORWARD: {
2856
+ SE_ASSIGN_OR_RETURN (
2857
+ *algorithm_desc,
2858
+ GetCudnnConvolutionForwardAlgorithm (
2859
+ stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2860
+ output_nd, scratch_allocator, scratch_memory));
2861
+ break ;
2862
+ }
2863
+ case dnn::ConvolutionKind::BACKWARD_DATA: {
2864
+ SE_ASSIGN_OR_RETURN (
2865
+ *algorithm_desc,
2866
+ GetCudnnConvolutionBackwardDataAlgorithm (
2867
+ stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2868
+ output_nd, scratch_allocator, scratch_memory));
2869
+ break ;
2870
+ }
2871
+ case dnn::ConvolutionKind::BACKWARD_FILTER: {
2872
+ SE_ASSIGN_OR_RETURN (
2873
+ *algorithm_desc,
2874
+ GetCudnnConvolutionBackwardFilterAlgorithm (
2875
+ stream, cudnn, algorithm_config, input_nd, filter_nd, conv,
2876
+ output_nd, scratch_allocator, scratch_memory));
2877
+ break ;
2878
+ }
2879
+ default :
2880
+ return port::InternalError (
2881
+ absl::StrCat (" Unexpected convolution kind " , static_cast <int >(kind)));
2882
+ }
2842
2883
2843
2884
return port::Status::OK ();
2844
2885
}
@@ -3351,64 +3392,6 @@ port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
3351
3392
return port::Status::OK ();
3352
3393
}
3353
3394
3354
- bool CudnnSupport::PrepareForConvolution (
3355
- Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
3356
- const DeviceMemory<float >& input_data,
3357
- const dnn::FilterDescriptor& filter_descriptor,
3358
- const DeviceMemory<float >& filter_data,
3359
- const dnn::ConvolutionDescriptor& convolution_descriptor,
3360
- const dnn::BatchDescriptor& output_descriptor,
3361
- DeviceMemory<float >* output_data, ScratchAllocator* scratch_allocator,
3362
- const dnn::AlgorithmConfig& algorithm_config,
3363
- dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
3364
- return IsStatusOk (PrepareForConvolutionImpl<float >(
3365
- stream, batch_descriptor, input_data, filter_descriptor,
3366
- filter_data, convolution_descriptor, output_descriptor,
3367
- output_data, dnn::DataType::kFloat , scratch_allocator,
3368
- algorithm_config, algorithm_desc, scratch_memory),
3369
- /* report_error=*/ true );
3370
- }
3371
-
3372
- bool CudnnSupport::PrepareForConvolution (
3373
- Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
3374
- const DeviceMemory<double >& input_data,
3375
- const dnn::FilterDescriptor& filter_descriptor,
3376
- const DeviceMemory<double >& filter_data,
3377
- const dnn::ConvolutionDescriptor& convolution_descriptor,
3378
- const dnn::BatchDescriptor& output_descriptor,
3379
- DeviceMemory<double >* output_data, ScratchAllocator* scratch_allocator,
3380
- const dnn::AlgorithmConfig& algorithm_config,
3381
- dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
3382
- return IsStatusOk (PrepareForConvolutionImpl<double >(
3383
- stream, batch_descriptor, input_data, filter_descriptor,
3384
- filter_data, convolution_descriptor, output_descriptor,
3385
- output_data, dnn::DataType::kDouble , scratch_allocator,
3386
- algorithm_config, algorithm_desc, scratch_memory),
3387
- /* report_error=*/ true );
3388
- }
3389
-
3390
- bool CudnnSupport::PrepareForConvolution (
3391
- Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
3392
- const DeviceMemory<Eigen::half>& input_data,
3393
- const dnn::FilterDescriptor& filter_descriptor,
3394
- const DeviceMemory<Eigen::half>& filter_data,
3395
- const dnn::ConvolutionDescriptor& convolution_descriptor,
3396
- const dnn::BatchDescriptor& output_descriptor,
3397
- DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
3398
- const dnn::AlgorithmConfig& algorithm_config,
3399
- dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
3400
- dnn::DataType acc_type =
3401
- CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled ()
3402
- ? dnn::DataType::kFloat
3403
- : dnn::DataType::kHalf ;
3404
- return IsStatusOk (
3405
- PrepareForConvolutionImpl<Eigen::half>(
3406
- stream, batch_descriptor, input_data, filter_descriptor, filter_data,
3407
- convolution_descriptor, output_descriptor, output_data, acc_type,
3408
- scratch_allocator, algorithm_config, algorithm_desc, scratch_memory),
3409
- /* report_error=*/ true );
3410
- }
3411
-
3412
3395
bool CudnnSupport::DoConvolve (
3413
3396
Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
3414
3397
const DeviceMemory<float >& input_data,
@@ -3592,36 +3575,6 @@ bool CudnnSupport::DoTransformTensor(Stream* stream,
3592
3575
return IsStatusOk (status, /* report_error=*/ true );
3593
3576
}
3594
3577
3595
- template <class T >
3596
- port::Status CudnnSupport::PrepareForConvolutionBackwardDataImpl (
3597
- Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
3598
- const DeviceMemory<T>& filter_data,
3599
- const dnn::BatchDescriptor& output_descriptor,
3600
- DeviceMemory<T> backward_output_data,
3601
- const dnn::ConvolutionDescriptor& convolution_descriptor,
3602
- const dnn::BatchDescriptor& input_descriptor,
3603
- DeviceMemory<T>* backward_input_data, dnn::DataType accumulator_type,
3604
- ScratchAllocator* scratch_allocator,
3605
- const dnn::AlgorithmConfig& algorithm_config,
3606
- dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
3607
- cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
3608
- auto cudnn = cudnn_->GetHandle (parent_, stream);
3609
-
3610
- CudnnTensorDescriptor out_back_nd (output_descriptor, cudnn_type);
3611
- CudnnTensorDescriptor in_back_nd (input_descriptor, cudnn_type);
3612
- CudnnFilterDescriptor filter (filter_descriptor, cudnn_type);
3613
- CudnnConvolutionDescriptor conv (convolution_descriptor,
3614
- ToCudnnDataType (accumulator_type));
3615
-
3616
- SE_ASSIGN_OR_RETURN (
3617
- *algorithm_desc,
3618
- GetCudnnConvolutionBackwardDataAlgorithm (
3619
- stream, cudnn, algorithm_config, in_back_nd, filter, conv,
3620
- out_back_nd, scratch_allocator, scratch_memory));
3621
-
3622
- return port::Status::OK ();
3623
- }
3624
-
3625
3578
template <class T >
3626
3579
port::Status CudnnSupport::DoConvolveBackwardDataImpl (
3627
3580
Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
@@ -3722,70 +3675,6 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl(
3722
3675
return port::Status::OK ();
3723
3676
}
3724
3677
3725
- bool CudnnSupport::PrepareForConvolutionBackwardData (
3726
- Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
3727
- const DeviceMemory<double >& filter_data,
3728
- const dnn::BatchDescriptor& output_descriptor,
3729
- DeviceMemory<double > backward_output_data,
3730
- const dnn::ConvolutionDescriptor& convolution_descriptor,
3731
- const dnn::BatchDescriptor& input_descriptor,
3732
- DeviceMemory<double >* backward_input_data,
3733
- ScratchAllocator* scratch_allocator,
3734
- const dnn::AlgorithmConfig& algorithm_config,
3735
- dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
3736
- return IsStatusOk (
3737
- PrepareForConvolutionBackwardDataImpl (
3738
- stream, filter_descriptor, filter_data, output_descriptor,
3739
- backward_output_data, convolution_descriptor, input_descriptor,
3740
- backward_input_data, dnn::DataType::kDouble , scratch_allocator,
3741
- algorithm_config, algorithm_desc, scratch_memory),
3742
- /* report_error=*/ true );
3743
- }
3744
-
3745
- bool CudnnSupport::PrepareForConvolutionBackwardData (
3746
- Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
3747
- const DeviceMemory<float >& filter_data,
3748
- const dnn::BatchDescriptor& output_descriptor,
3749
- DeviceMemory<float > backward_output_data,
3750
- const dnn::ConvolutionDescriptor& convolution_descriptor,
3751
- const dnn::BatchDescriptor& input_descriptor,
3752
- DeviceMemory<float >* backward_input_data,
3753
- ScratchAllocator* scratch_allocator,
3754
- const dnn::AlgorithmConfig& algorithm_config,
3755
- dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
3756
- return IsStatusOk (
3757
- PrepareForConvolutionBackwardDataImpl (
3758
- stream, filter_descriptor, filter_data, output_descriptor,
3759
- backward_output_data, convolution_descriptor, input_descriptor,
3760
- backward_input_data, dnn::DataType::kFloat , scratch_allocator,
3761
- algorithm_config, algorithm_desc, scratch_memory),
3762
- /* report_error=*/ true );
3763
- }
3764
-
3765
- bool CudnnSupport::PrepareForConvolutionBackwardData (
3766
- Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
3767
- const DeviceMemory<Eigen::half>& filter_data,
3768
- const dnn::BatchDescriptor& output_descriptor,
3769
- DeviceMemory<Eigen::half> backward_output_data,
3770
- const dnn::ConvolutionDescriptor& convolution_descriptor,
3771
- const dnn::BatchDescriptor& input_descriptor,
3772
- DeviceMemory<Eigen::half>* backward_input_data,
3773
- ScratchAllocator* scratch_allocator,
3774
- const dnn::AlgorithmConfig& algorithm_config,
3775
- dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
3776
- dnn::DataType acc_type =
3777
- CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled ()
3778
- ? dnn::DataType::kFloat
3779
- : dnn::DataType::kHalf ;
3780
- return IsStatusOk (
3781
- PrepareForConvolutionBackwardDataImpl (
3782
- stream, filter_descriptor, filter_data, output_descriptor,
3783
- backward_output_data, convolution_descriptor, input_descriptor,
3784
- backward_input_data, acc_type, scratch_allocator, algorithm_config,
3785
- algorithm_desc, scratch_memory),
3786
- /* report_error=*/ true );
3787
- }
3788
-
3789
3678
bool CudnnSupport::DoConvolveBackwardData (
3790
3679
Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
3791
3680
const DeviceMemory<double >& filter_data,
@@ -3846,36 +3735,6 @@ bool CudnnSupport::DoConvolveBackwardData(
3846
3735
/* report_error=*/ !output_profile_result);
3847
3736
}
3848
3737
3849
- template <class T >
3850
- port::Status CudnnSupport::PrepareForConvolutionBackwardFilterImpl (
3851
- Stream* stream, const dnn::BatchDescriptor& input_descriptor,
3852
- const DeviceMemory<T>& input_data,
3853
- const dnn::BatchDescriptor& output_descriptor,
3854
- DeviceMemory<T> backward_output_data,
3855
- const dnn::ConvolutionDescriptor& convolution_descriptor,
3856
- const dnn::FilterDescriptor& filter_descriptor,
3857
- DeviceMemory<T>* backward_filter_data, dnn::DataType accumulator_type,
3858
- ScratchAllocator* scratch_allocator,
3859
- const dnn::AlgorithmConfig& algorithm_config,
3860
- dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
3861
- cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
3862
- auto cudnn = cudnn_->GetHandle (parent_, stream);
3863
-
3864
- CudnnTensorDescriptor out_back_nd (output_descriptor, cudnn_type);
3865
- CudnnTensorDescriptor input_nd (input_descriptor, cudnn_type);
3866
- CudnnFilterDescriptor filter (filter_descriptor, cudnn_type);
3867
- CudnnConvolutionDescriptor conv (convolution_descriptor,
3868
- ToCudnnDataType (accumulator_type));
3869
-
3870
- SE_ASSIGN_OR_RETURN (
3871
- *algorithm_desc,
3872
- GetCudnnConvolutionBackwardFilterAlgorithm (
3873
- stream, cudnn, algorithm_config, input_nd, filter, conv, out_back_nd,
3874
- scratch_allocator, scratch_memory));
3875
-
3876
- return port::Status::OK ();
3877
- }
3878
-
3879
3738
template <class T >
3880
3739
port::Status CudnnSupport::DoConvolveBackwardFilterImpl (
3881
3740
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
@@ -4013,70 +3872,6 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl(
4013
3872
return port::Status::OK ();
4014
3873
}
4015
3874
4016
- bool CudnnSupport::PrepareForConvolutionBackwardFilter (
4017
- Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4018
- const DeviceMemory<double >& input_data,
4019
- const dnn::BatchDescriptor& output_descriptor,
4020
- DeviceMemory<double > backward_output_data,
4021
- const dnn::ConvolutionDescriptor& convolution_descriptor,
4022
- const dnn::FilterDescriptor& filter_descriptor,
4023
- DeviceMemory<double >* backward_filter_data,
4024
- ScratchAllocator* scratch_allocator,
4025
- const dnn::AlgorithmConfig& algorithm_config,
4026
- dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
4027
- return IsStatusOk (
4028
- PrepareForConvolutionBackwardFilterImpl (
4029
- stream, input_descriptor, input_data, output_descriptor,
4030
- backward_output_data, convolution_descriptor, filter_descriptor,
4031
- backward_filter_data, dnn::DataType::kDouble , scratch_allocator,
4032
- algorithm_config, algorithm_desc, scratch_memory),
4033
- /* report_error=*/ true );
4034
- }
4035
-
4036
- bool CudnnSupport::PrepareForConvolutionBackwardFilter (
4037
- Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4038
- const DeviceMemory<float >& input_data,
4039
- const dnn::BatchDescriptor& output_descriptor,
4040
- DeviceMemory<float > backward_output_data,
4041
- const dnn::ConvolutionDescriptor& convolution_descriptor,
4042
- const dnn::FilterDescriptor& filter_descriptor,
4043
- DeviceMemory<float >* backward_filter_data,
4044
- ScratchAllocator* scratch_allocator,
4045
- const dnn::AlgorithmConfig& algorithm_config,
4046
- dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
4047
- return IsStatusOk (
4048
- PrepareForConvolutionBackwardFilterImpl (
4049
- stream, input_descriptor, input_data, output_descriptor,
4050
- backward_output_data, convolution_descriptor, filter_descriptor,
4051
- backward_filter_data, dnn::DataType::kFloat , scratch_allocator,
4052
- algorithm_config, algorithm_desc, scratch_memory),
4053
- /* report_error=*/ true );
4054
- }
4055
-
4056
- bool CudnnSupport::PrepareForConvolutionBackwardFilter (
4057
- Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4058
- const DeviceMemory<Eigen::half>& input_data,
4059
- const dnn::BatchDescriptor& output_descriptor,
4060
- DeviceMemory<Eigen::half> backward_output_data,
4061
- const dnn::ConvolutionDescriptor& convolution_descriptor,
4062
- const dnn::FilterDescriptor& filter_descriptor,
4063
- DeviceMemory<Eigen::half>* backward_filter_data,
4064
- ScratchAllocator* scratch_allocator,
4065
- const dnn::AlgorithmConfig& algorithm_config,
4066
- dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) {
4067
- dnn::DataType acc_type =
4068
- CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled ()
4069
- ? dnn::DataType::kFloat
4070
- : dnn::DataType::kHalf ;
4071
- return IsStatusOk (
4072
- PrepareForConvolutionBackwardFilterImpl (
4073
- stream, input_descriptor, input_data, output_descriptor,
4074
- backward_output_data, convolution_descriptor, filter_descriptor,
4075
- backward_filter_data, acc_type, scratch_allocator, algorithm_config,
4076
- algorithm_desc, scratch_memory),
4077
- /* report_error=*/ true );
4078
- }
4079
-
4080
3875
bool CudnnSupport::DoConvolveBackwardFilter (
4081
3876
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4082
3877
const DeviceMemory<double >& input_data,
0 commit comments