@@ -24,25 +24,78 @@ namespace operators {
 using Tensor = framework::Tensor;

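+// Hard-label cross entropy: one thread per sample reads the predicted
+// probability of its ground-truth class, i.e. loss[i] = -log(p[i][label[i]]).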
 template <typename T>
-__global__ void CrossEntropyKernel(T* out, const T* softmax_out,
-                                   const int* label, const int batch_size,
-                                   const int class_num) {
+__global__ void CrossEntropy(T* out, const T* softmax_out, const int* labels,
+                             const int batch_size, const int class_num) {
   int i = blockIdx.x * blockDim.x + threadIdx.x;
   if (i < batch_size) {
-    PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
-    out[i] = -tolerable_value(std::log(softmax_out[i * class_num + label[i]]));
+    PADDLE_ASSERT(labels[i] >= 0 && labels[i] < class_num);
+    out[i] =
+        -TolerableValue<T>()(std::log(softmax_out[i * class_num + labels[i]]));
   }
 }

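+// Gradient of the hard-label loss: with softmax output p and one-hot label
+// y, d(loss)/d(logit) = (p - y) * d(loss)/d(out).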
 template <typename T>
-__global__ void CrossEntropyWithSoftmaxGradKernel(T* softmax_out,
-                                                  const int* label,
-                                                  const int batch_size,
-                                                  const int class_num) {
-  int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i < batch_size) {
-    PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
-    softmax_out[i * class_num + label[i]] -= 1.;
+__global__ void CrossEntropyGrad(T* out_grad, const T* in_grad,
+                                 const int* labels, const int batch_size,
+                                 const int class_num) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int sample_idx = tid / class_num;
+
+  if (tid < batch_size * class_num) {
+    PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num);
+    // d(loss)/d(logit) = (softmax - one_hot(label)) * d(loss)/d(out),
+    // computed in one pass per element: splitting it into a scale phase and
+    // a subtract phase cannot be ordered across blocks by __syncthreads(),
+    // and the subtrahend must be scaled by the loss gradient as well.
+    T one_hot = tid % class_num == labels[sample_idx] ? static_cast<T>(1.)
+                                                      : static_cast<T>(0.);
+    out_grad[tid] = (out_grad[tid] - one_hot) * in_grad[sample_idx];
+  }
+}
+
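+// Sums the 32 lanes of a warp with register shuffles; __shfl_down is the
+// pre-CUDA-9 intrinsic (superseded by __shfl_down_sync).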
+template <typename T>
+__device__ __forceinline__ T sum_single_warp(T val) {
+  val += __shfl_down(val, 16);
+  val += __shfl_down(val, 8);
+  val += __shfl_down(val, 4);
+  val += __shfl_down(val, 2);
+  val += __shfl_down(val, 1);
+  return val;
+}
+
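+// Soft-label cross entropy, loss[i] = -sum_j label[i][j] * log(p[i][j]).
+// One block handles one sample: threads stride over the classes into shared
+// memory, a tree reduction shrinks the partial sums down to one warp, and
+// sum_single_warp finishes. Assumes blockDim.x >= 32 (one full warp).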
+template <typename T>
+__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
+                                       const int class_num) {
+  int tid = threadIdx.x;
+  extern __shared__ T d_sum[];
+  d_sum[tid] = 0;
+
+  int cur_idx = tid;
+  int next_idx = blockIdx.x * class_num + tid;
+  while (cur_idx < class_num) {
+    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
+    next_idx += blockDim.x;
+    cur_idx += blockDim.x;
+  }
+  __syncthreads();
+
+  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
+    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
+    __syncthreads();
+  }
+
+  T val = d_sum[tid];
+  val = sum_single_warp<T>(val);
+  if (tid == 0) Y[blockIdx.x] = -val;
+}
+
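+// Gradient of the soft-label loss: element-wise
+// d(loss)/d(logit) = (softmax - labels) * loss_grad, with logit_grad
+// pre-loaded with the forward softmax output.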
+template <typename T>
+__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
+                                               const T* loss_grad,
+                                               const T* labels,
+                                               const int batch_size,
+                                               const int class_num) {
+  int ids = blockIdx.x * blockDim.x + threadIdx.x;
+  if (ids < batch_size * class_num) {
+    int row_ids = ids / class_num;
+    // The label term is scaled by the incoming loss gradient too, giving
+    // (softmax - labels) * loss_grad element-wise.
+    logit_grad[ids] = (logit_grad[ids] - labels[ids]) * loss_grad[row_ids];
   }
 }

@@ -52,27 +105,36 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                    "This kernel only runs on GPU device.");
+    T* loss_data =
+        context.Output<Tensor>("Loss")->mutable_data<T>(context.GetPlace());

-    // Calculate the softmax outputs.
     const Tensor* logits = context.Input<Tensor>("Logits");
     Tensor* softmax = context.Output<Tensor>("Softmax");
-    softmax->mutable_data<T>(context.GetPlace());
-    math::SoftmaxFunctor<platform::GPUPlace, T>()(logits, softmax, context);
-    T* softmax_out = softmax->data<T>();
-
-    // Calculate the cross entropy loss based on hard labels.
-    const int* label_data = context.Input<Tensor>("Label")->data<int>();
-    Tensor* loss = context.Output<Tensor>("Loss");
-    loss->mutable_data<T>(context.GetPlace());
-    T* loss_data = loss->data<T>();
+    T* softmax_out = softmax->mutable_data<T>(context.GetPlace());
+    math::SoftmaxFunctor<platform::GPUPlace, T>()(context, logits, softmax);

     const int batch_size = logits->dims()[0];
     const int class_num = logits->dims()[1];
     int block = 512;
     int grid = (batch_size + block - 1) / block;

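+    // "softLabel" selects between soft targets (per-class probabilities of
+    // type T) and hard targets (one int class id per sample).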
-    CrossEntropyKernel<T><<<grid, block>>>(loss_data, softmax_out, label_data,
-                                           batch_size, class_num);
+    if (context.Attr<bool>("softLabel")) {
+      const T* label_data = context.Input<Tensor>("Label")->data<T>();
+      // One block per sample; the block size is the largest power of two
+      // that is <= class_num, capped at 512, so the shared-memory reduction
+      // runs over a power-of-two range.
+      block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
+
+      SoftCrossEntropyKernel<
+          T><<<batch_size, block, block * sizeof(T),
+               reinterpret_cast<const platform::CUDADeviceContext&>(
+                   context.device_context())
+                   .stream()>>>(loss_data, softmax_out, label_data, class_num);
+    } else {
+      const int* label_data = context.Input<Tensor>("Label")->data<int>();
+      CrossEntropy<T><<<grid, block, 0,
+                        reinterpret_cast<const platform::CUDADeviceContext&>(
+                            context.device_context())
+                            .stream()>>>(loss_data, softmax_out, label_data,
+                                         batch_size, class_num);
+    }
   }
 };

@@ -82,22 +144,34 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                    "This kernel only runs on GPU device.");
-
+    const Tensor* labels = context.Input<Tensor>("Label");
+    const T* loss_grad_data =
+        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
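+    // The forward softmax output is reused as the gradient buffer; both
+    // kernels below overwrite it in place.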
     Tensor* logit_grad =
         context.Output<Tensor>(framework::GradVarName("Logits"));
     logit_grad->ShareDataWith<T>(*context.Input<Tensor>("Softmax"));
     T* logit_grad_data = logit_grad->data<T>();

     const int batch_size = logit_grad->dims()[0];
     const int class_num = logit_grad->dims()[1];
-
-    const int* label_data = context.Input<Tensor>("Label")->data<int>();
-
-    const int block = 512;
-    const int grid = (batch_size + block - 1) / block;
-
-    CrossEntropyWithSoftmaxGradKernel<T><<<grid, block>>>(
-        logit_grad_data, label_data, batch_size, class_num);
+    int block = 512;
+    int grid = (batch_size * class_num + block - 1) / block;
+
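+    // Both gradient kernels assign one thread to each element of the
+    // batch_size x class_num gradient matrix.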
+    if (context.Attr<bool>("softLabel")) {
+      const T* label_data = labels->data<T>();
+      SoftCrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              context.device_context())
+                              .stream()>>>(logit_grad_data, loss_grad_data,
+                                           label_data, batch_size, class_num);
+    } else {
+      const int* label_data = labels->data<int>();
+      CrossEntropyGrad<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              context.device_context())
+                              .stream()>>>(logit_grad_data, loss_grad_data,
+                                           label_data, batch_size, class_num);
+    }
   }
 };
